From dd2325d9a7de7bef9a6bc2f0d5f26e605545b52d Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 25 Jan 2016 11:52:26 -0800 Subject: [PATCH 001/131] [SPARK-11965][ML][DOC] Update user guide for RFormula feature interactions Update user guide for RFormula feature interactions. Meanwhile we also update other new features such as supporting string label in Spark 1.6. Author: Yanbo Liang Closes #10222 from yanboliang/spark-11965. --- docs/ml-features.md | 20 +++++++++++++++++- .../apache/spark/ml/feature/RFormula.scala | 21 +++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index 677e4bfb916e..5809f65d637e 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1121,7 +1121,25 @@ for more details on the API. ## RFormula -`RFormula` selects columns specified by an [R model formula](https://stat.ethz.ch/R-manual/R-devel/library/stats/html/formula.html). It produces a vector column of features and a double column of labels. Like when formulas are used in R for linear regression, string input columns will be one-hot encoded, and numeric columns will be cast to doubles. If not already present in the DataFrame, the output label column will be created from the specified response variable in the formula. +`RFormula` selects columns specified by an [R model formula](https://stat.ethz.ch/R-manual/R-devel/library/stats/html/formula.html). +Currently we support a limited subset of the R operators, including '~', '.', ':', '+', and '-'. +The basic operators are: + +* `~` separate target and terms +* `+` concat terms, "+ 0" means removing intercept +* `-` remove a term, "- 1" means removing intercept +* `:` interaction (multiplication for numeric values, or binarized categorical values) +* `.` all columns except target + +Suppose `a` and `b` are double columns, we use the following simple examples to illustrate the effect of `RFormula`: + +* `y ~ a + b` means model `y ~ w0 + w1 * a + w2 * b` where `w0` is the intercept and `w1, w2` are coefficients. +* `y ~ a + b + a:b - 1` means model `y ~ w1 * a + w2 * b + w3 * a * b` where `w1, w2, w3` are coefficients. + +`RFormula` produces a vector column of features and a double or string column of label. +Like when formulas are used in R for linear regression, string input columns will be one-hot encoded, and numeric columns will be cast to doubles. +If the label column is of type string, it will be first transformed to double with `StringIndexer`. +If the label column does not exist in the DataFrame, the output label column will be created from the specified response variable in the formula. **Examples** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 6cc9d025445c..c21da218b36d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -45,6 +45,27 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol { * Implements the transforms required for fitting a dataset against an R model formula. Currently * we support a limited subset of the R operators, including '~', '.', ':', '+', and '-'. 
Also see * the R formula docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html + * + * The basic operators are: + * - `~` separate target and terms + * - `+` concat terms, "+ 0" means removing intercept + * - `-` remove a term, "- 1" means removing intercept + * - `:` interaction (multiplication for numeric values, or binarized categorical values) + * - `.` all columns except target + * + * Suppose `a` and `b` are double columns, we use the following simple examples + * to illustrate the effect of `RFormula`: + * - `y ~ a + b` means model `y ~ w0 + w1 * a + w2 * b` where `w0` is the intercept and `w1, w2` + * are coefficients. + * - `y ~ a + b + a:b - 1` means model `y ~ w1 * a + w2 * b + w3 * a * b` where `w1, w2, w3` + * are coefficients. + * + * RFormula produces a vector column of features and a double or string column of label. + * Like when formulas are used in R for linear regression, string input columns will be one-hot + * encoded, and numeric columns will be cast to doubles. + * If the label column is of type string, it will be first transformed to double with + * `StringIndexer`. If the label column does not exist in the DataFrame, the output label column + * will be created from the specified response variable in the formula. */ @Experimental class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase { From ef8fb3612c7be1ac9058750be39ee28d88a148b4 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 25 Jan 2016 12:36:53 -0800 Subject: [PATCH 002/131] Closes #10879 Closes #9046 Closes #8532 Closes #10756 Closes #8960 Closes #10485 Closes #10467 From c037d25482ea63430fb42bfd86124c268be5a4a4 Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Mon, 25 Jan 2016 14:42:44 -0600 Subject: [PATCH 003/131] [SPARK-12149][WEB UI] Executor UI improvement suggestions - Color UI Added color coding to the Executors page for Active Tasks, Failed Tasks, Completed Tasks and Task Time. Active Tasks is shaded blue with it's range based on percentage of total cores used. Failed Tasks is shaded red ranging over the first 10% of total tasks failed Completed Tasks is shaded green ranging over 10% of total tasks including failed and active tasks, but only when there are active or failed tasks on that executor. Task Time is shaded red when GC Time goes over 10% of total time with it's range directly corresponding to the percent of total time. Author: Alex Bozarth Closes #10154 from ajbozarth/spark12149. 
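The shading rules above map to the opacity calculations added to `ExecutorsPage.scala` further down in this patch. Pulled out into standalone helpers, the mapping looks roughly like this (a sketch only — the helper names are illustrative and not part of the patch):

```
// Cell opacity ranges from 0.5 to 1.0, mirroring the taskData method added below.
def activeTasksAlpha(activeTasks: Int, maxTasks: Int): Double =
  if (maxTasks > 0) (activeTasks.toDouble / maxTasks) * 0.5 + 0.5 else 1.0

// Failed-task shading saturates once failures reach 10% of total tasks.
def failedTasksAlpha(failedTasks: Int, totalTasks: Int): Double =
  if (totalTasks > 0) math.min(10 * failedTasks.toDouble / totalTasks, 1) * 0.5 + 0.5 else 1.0

// Task-time cells get this opacity once GC time exceeds GCTimePercent (0.1) of total task time.
def taskTimeAlpha(totalGCTime: Long, totalDuration: Long): Double =
  if (totalDuration > 0) math.min(totalGCTime.toDouble / totalDuration + 0.5, 1.0) else 1.0
```

The resulting alpha is plugged into an `hsla(...)` background on the corresponding table cell, as the diff below shows.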
--- .../org/apache/spark/status/api/v1/api.scala | 2 + .../scala/org/apache/spark/ui/SparkUI.scala | 2 +- .../scala/org/apache/spark/ui/ToolTips.scala | 3 + .../apache/spark/ui/exec/ExecutorsPage.scala | 98 +++++++++++++++---- .../apache/spark/ui/exec/ExecutorsTab.scala | 10 +- .../executor_list_json_expectation.json | 2 + project/MimaExcludes.scala | 6 ++ 7 files changed, 103 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index fe372116f1b6..3adf5b1109af 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -55,11 +55,13 @@ class ExecutorSummary private[spark]( val rddBlocks: Int, val memoryUsed: Long, val diskUsed: Long, + val maxTasks: Int, val activeTasks: Int, val failedTasks: Int, val completedTasks: Int, val totalTasks: Int, val totalDuration: Long, + val totalGCTime: Long, val totalInputBytes: Long, val totalShuffleRead: Long, val totalShuffleWrite: Long, diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index eb53aa8e23ae..cf45414c4f78 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -195,7 +195,7 @@ private[spark] object SparkUI { val environmentListener = new EnvironmentListener val storageStatusListener = new StorageStatusListener - val executorsListener = new ExecutorsListener(storageStatusListener) + val executorsListener = new ExecutorsListener(storageStatusListener, conf) val storageListener = new StorageListener(storageStatusListener) val operationGraphListener = new RDDOperationGraphListener(conf) diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index cb122eaed83d..2d2d80be4aab 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -87,4 +87,7 @@ private[spark] object ToolTips { multiple operations (e.g. two map() functions) if they can be pipelined. Some operations also create multiple RDDs internally. Cached RDDs are shown in green. 
""" + + val TASK_TIME = + "Shaded red when garbage collection (GC) time is over 10% of task time" } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index 440dfa267956..e36b96b3e697 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -50,6 +50,8 @@ private[ui] class ExecutorsPage( threadDumpEnabled: Boolean) extends WebUIPage("") { private val listener = parent.listener + // When GCTimePercent is edited change ToolTips.TASK_TIME to match + private val GCTimePercent = 0.1 def render(request: HttpServletRequest): Seq[Node] = { val (storageStatusList, execInfo) = listener.synchronized { @@ -77,7 +79,7 @@ private[ui] class ExecutorsPage( Failed Tasks Complete Tasks Total Tasks - Task Time + Task Time (GC Time) Input Shuffle Read @@ -129,13 +131,8 @@ private[ui] class ExecutorsPage( {Utils.bytesToString(diskUsed)} - {info.activeTasks} - {info.failedTasks} - {info.completedTasks} - {info.totalTasks} - - {Utils.msDurationToString(info.totalDuration)} - + {taskData(info.maxTasks, info.activeTasks, info.failedTasks, info.completedTasks, + info.totalTasks, info.totalDuration, info.totalGCTime)} {Utils.bytesToString(info.totalInputBytes)} @@ -177,7 +174,6 @@ private[ui] class ExecutorsPage( val maximumMemory = execInfo.map(_.maxMemory).sum val memoryUsed = execInfo.map(_.memoryUsed).sum val diskUsed = execInfo.map(_.diskUsed).sum - val totalDuration = execInfo.map(_.totalDuration).sum val totalInputBytes = execInfo.map(_.totalInputBytes).sum val totalShuffleRead = execInfo.map(_.totalShuffleRead).sum val totalShuffleWrite = execInfo.map(_.totalShuffleWrite).sum @@ -192,13 +188,13 @@ private[ui] class ExecutorsPage( {Utils.bytesToString(diskUsed)} - {execInfo.map(_.activeTasks).sum} - {execInfo.map(_.failedTasks).sum} - {execInfo.map(_.completedTasks).sum} - {execInfo.map(_.totalTasks).sum} - - {Utils.msDurationToString(totalDuration)} - + {taskData(execInfo.map(_.maxTasks).sum, + execInfo.map(_.activeTasks).sum, + execInfo.map(_.failedTasks).sum, + execInfo.map(_.completedTasks).sum, + execInfo.map(_.totalTasks).sum, + execInfo.map(_.totalDuration).sum, + execInfo.map(_.totalGCTime).sum)} {Utils.bytesToString(totalInputBytes)} @@ -219,7 +215,7 @@ private[ui] class ExecutorsPage( Failed Tasks Complete Tasks Total Tasks - Task Time + Task Time (GC Time) Input Shuffle Read @@ -233,6 +229,70 @@ private[ui] class ExecutorsPage( } + + private def taskData( + maxTasks: Int, + activeTasks: Int, + failedTasks: Int, + completedTasks: Int, + totalTasks: Int, + totalDuration: Long, + totalGCTime: Long): + Seq[Node] = { + // Determine Color Opacity from 0.5-1 + // activeTasks range from 0 to maxTasks + val activeTasksAlpha = + if (maxTasks > 0) { + (activeTasks.toDouble / maxTasks) * 0.5 + 0.5 + } else { + 1 + } + // failedTasks range max at 10% failure, alpha max = 1 + val failedTasksAlpha = + if (totalTasks > 0) { + math.min(10 * failedTasks.toDouble / totalTasks, 1) * 0.5 + 0.5 + } else { + 1 + } + // totalDuration range from 0 to 50% GC time, alpha max = 1 + val totalDurationAlpha = + if (totalDuration > 0) { + math.min(totalGCTime.toDouble / totalDuration + 0.5, 1) + } else { + 1 + } + + val tableData = + 0) { + "background:hsla(240, 100%, 50%, " + activeTasksAlpha + ");color:white" + } else { + "" + } + }>{activeTasks} + 0) { + "background:hsla(0, 100%, 50%, " + failedTasksAlpha + ");color:white" + } else { + "" + 
} + }>{failedTasks} + {completedTasks} + {totalTasks} + GCTimePercent * totalDuration) { + "background:hsla(0, 100%, 50%, " + totalDurationAlpha + ");color:white" + } else { + "" + } + }> + {Utils.msDurationToString(totalDuration)} + ({Utils.msDurationToString(totalGCTime)}) + ; + + tableData + } } private[spark] object ExecutorsPage { @@ -245,11 +305,13 @@ private[spark] object ExecutorsPage { val memUsed = status.memUsed val maxMem = status.maxMem val diskUsed = status.diskUsed + val maxTasks = listener.executorToTasksMax.getOrElse(execId, 0) val activeTasks = listener.executorToTasksActive.getOrElse(execId, 0) val failedTasks = listener.executorToTasksFailed.getOrElse(execId, 0) val completedTasks = listener.executorToTasksComplete.getOrElse(execId, 0) val totalTasks = activeTasks + failedTasks + completedTasks val totalDuration = listener.executorToDuration.getOrElse(execId, 0L) + val totalGCTime = listener.executorToJvmGCTime.getOrElse(execId, 0L) val totalInputBytes = listener.executorToInputBytes.getOrElse(execId, 0L) val totalShuffleRead = listener.executorToShuffleRead.getOrElse(execId, 0L) val totalShuffleWrite = listener.executorToShuffleWrite.getOrElse(execId, 0L) @@ -261,11 +323,13 @@ private[spark] object ExecutorsPage { rddBlocks, memUsed, diskUsed, + maxTasks, activeTasks, failedTasks, completedTasks, totalTasks, totalDuration, + totalGCTime, totalInputBytes, totalShuffleRead, totalShuffleWrite, diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 160d7a4dff2d..a9e926b15878 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -19,7 +19,7 @@ package org.apache.spark.ui.exec import scala.collection.mutable.HashMap -import org.apache.spark.{ExceptionFailure, Resubmitted, SparkContext} +import org.apache.spark.{ExceptionFailure, Resubmitted, SparkConf, SparkContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.storage.{StorageStatus, StorageStatusListener} @@ -43,11 +43,14 @@ private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "exec * A SparkListener that prepares information to be displayed on the ExecutorsTab */ @DeveloperApi -class ExecutorsListener(storageStatusListener: StorageStatusListener) extends SparkListener { +class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: SparkConf) + extends SparkListener { + val executorToTasksMax = HashMap[String, Int]() val executorToTasksActive = HashMap[String, Int]() val executorToTasksComplete = HashMap[String, Int]() val executorToTasksFailed = HashMap[String, Int]() val executorToDuration = HashMap[String, Long]() + val executorToJvmGCTime = HashMap[String, Long]() val executorToInputBytes = HashMap[String, Long]() val executorToInputRecords = HashMap[String, Long]() val executorToOutputBytes = HashMap[String, Long]() @@ -62,6 +65,8 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = synchronized { val eid = executorAdded.executorId executorToLogUrls(eid) = executorAdded.executorInfo.logUrlMap + executorToTasksMax(eid) = + executorAdded.executorInfo.totalCores / conf.getInt("spark.task.cpus", 1) executorIdToData(eid) = ExecutorUIData(executorAdded.time) } @@ -131,6 +136,7 @@ class ExecutorsListener(storageStatusListener: 
StorageStatusListener) extends Sp executorToShuffleWrite(eid) = executorToShuffleWrite.getOrElse(eid, 0L) + shuffleWrite.bytesWritten } + executorToJvmGCTime(eid) = executorToJvmGCTime.getOrElse(eid, 0L) + metrics.jvmGCTime } } } diff --git a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json index cb622e147249..94f8aeac55b5 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json @@ -4,11 +4,13 @@ "rddBlocks" : 8, "memoryUsed" : 28000128, "diskUsed" : 0, + "maxTasks" : 0, "activeTasks" : 0, "failedTasks" : 1, "completedTasks" : 31, "totalTasks" : 32, "totalDuration" : 8820, + "totalGCTime" : 352, "totalInputBytes" : 28000288, "totalShuffleRead" : 0, "totalShuffleWrite" : 13180, diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index c65fae482c5c..501456b04317 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -127,6 +127,9 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.ExternalBlockStore$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.ExternalBlockManager"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.ExternalBlockStore") + ) ++ Seq( + // SPARK-12149 Added new fields to ExecutorSummary + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorSummary.this") ) ++ // SPARK-12665 Remove deprecated and unused classes Seq( @@ -301,6 +304,9 @@ object MimaExcludes { // SPARK-3580 Add getNumPartitions method to JavaRDD ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.api.java.JavaRDDLike.getNumPartitions") + ) ++ Seq( + // SPARK-12149 Added new fields to ExecutorSummary + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorSummary.this") ) ++ // SPARK-11314: YARN backend moved to yarn sub-module and MiMA complains even though it's a // private class. From 7d877c3439d872ec2a9e07d245e9c96174c0cf00 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 25 Jan 2016 12:44:20 -0800 Subject: [PATCH 004/131] [SPARK-12902] [SQL] visualization for generated operators This PR brings back visualization for generated operators, they looks like: ![sql](https://cloud.githubusercontent.com/assets/40902/12460920/0dc7956a-bf6b-11e5-9c3f-8389f452526e.png) ![stage](https://cloud.githubusercontent.com/assets/40902/12460923/11806ac4-bf6b-11e5-9c72-e84a62c5ea93.png) Note: SQL metrics are not supported right now, because they are very slow, will be supported once we have batch mode. Author: Davies Liu Closes #10828 from davies/viz_codegen. 
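For reference, a minimal query whose physical plan is wrapped in `WholeStageCodegen` and therefore shows up as one of the new clusters (the plan shape in the comment is the one assumed by the `WholeStageCodegen metrics` test added below):

```
// Assumed physical plan, per the test added in SQLMetricsSuite:
//   WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Filter(nodeId = 1))
// Range and Filter are drawn inside a single WholeStageCodegen cluster on the SQL tab.
import sqlContext.implicits._
val df = sqlContext.range(10).filter('id < 5)
df.collect()
```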
--- .../apache/spark/ui/static/spark-dag-viz.js | 2 +- .../spark/ui/scope/RDDOperationGraph.scala | 6 +- .../sql/execution/ui/static/spark-sql-viz.css | 6 ++ .../spark/sql/execution/SparkPlanInfo.scala | 9 +- .../sql/execution/ui/ExecutionPage.scala | 4 +- .../spark/sql/execution/ui/SQLListener.scala | 2 +- .../sql/execution/ui/SparkPlanGraph.scala | 95 ++++++++++++++----- .../execution/metric/SQLMetricsSuite.scala | 10 +- .../sql/execution/ui/SQLListenerSuite.scala | 2 +- 9 files changed, 104 insertions(+), 32 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js index 83dbea40b63f..4337c42087e7 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js @@ -284,7 +284,7 @@ function renderDot(dot, container, forJob) { renderer(container, g); // Find the stage cluster and mark it for styling and post-processing - container.selectAll("g.cluster[name*=\"Stage\"]").classed("stage", true); + container.selectAll("g.cluster[name^=\"Stage \"]").classed("stage", true); } /* -------------------- * diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index 06da74f1b6b5..003c218aada9 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -130,7 +130,11 @@ private[ui] object RDDOperationGraph extends Logging { } } // Attach the outermost cluster to the root cluster, and the RDD to the innermost cluster - rddClusters.headOption.foreach { cluster => rootCluster.attachChildCluster(cluster) } + rddClusters.headOption.foreach { cluster => + if (!rootCluster.childClusters.contains(cluster)) { + rootCluster.attachChildCluster(cluster) + } + } rddClusters.lastOption.foreach { cluster => cluster.attachChildNode(node) } } } diff --git a/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css index ddd3a91dd8ef..303f8ebb8814 100644 --- a/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css +++ b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css @@ -20,6 +20,12 @@ text-shadow: none; } +#plan-viz-graph svg g.cluster rect { + fill: #A0DFFF; + stroke: #3EC0FF; + stroke-width: 1px; +} + #plan-viz-graph svg g.node rect { fill: #C3EBFF; stroke: #3EC0FF; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala index 4f750ad13ab8..4dd992824419 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -36,12 +36,17 @@ class SparkPlanInfo( private[sql] object SparkPlanInfo { def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = { + val children = plan match { + case WholeStageCodegen(child, _) => child :: Nil + case InputAdapter(child) => child :: Nil + case plan => plan.children + } val metrics = plan.metrics.toSeq.map { case (key, metric) => new SQLMetricInfo(metric.name.getOrElse(key), metric.id, Utils.getFormattedClassName(metric.param)) } - val children = plan.children.map(fromSparkPlan) - new 
SparkPlanInfo(plan.nodeName, plan.simpleString, children, plan.metadata, metrics) + new SparkPlanInfo(plan.nodeName, plan.simpleString, children.map(fromSparkPlan), + plan.metadata, metrics) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index c74ad4040699..49915adf6cd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -99,7 +99,7 @@ private[sql] class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") } private def planVisualization(metrics: Map[Long, String], graph: SparkPlanGraph): Seq[Node] = { - val metadata = graph.nodes.flatMap { node => + val metadata = graph.allNodes.flatMap { node => val nodeId = s"plan-meta-data-${node.id}"
{node.desc}
} @@ -110,7 +110,7 @@ private[sql] class ExecutionPage(parent: SQLTab) extends WebUIPage("execution")
{graph.makeDotFile(metrics)}
-      {graph.nodes.size.toString}
+      {graph.allNodes.size.toString}
{metadata} {planVisualizationResources} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index cd5613692708..83c64f755f90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -219,7 +219,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi case SparkListenerSQLExecutionStart(executionId, description, details, physicalPlanDescription, sparkPlanInfo, time) => val physicalPlanGraph = SparkPlanGraph(sparkPlanInfo) - val sqlPlanMetrics = physicalPlanGraph.nodes.flatMap { node => + val sqlPlanMetrics = physicalPlanGraph.allNodes.flatMap { node => node.metrics.map(metric => metric.accumulatorId -> metric) } val executionUIData = new SQLExecutionUIData( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 3a6eff939982..4eb248569b28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import org.apache.spark.sql.execution.SparkPlanInfo +import org.apache.spark.sql.execution.{InputAdapter, SparkPlanInfo, WholeStageCodegen} import org.apache.spark.sql.execution.metric.SQLMetrics /** @@ -41,6 +41,16 @@ private[ui] case class SparkPlanGraph( dotFile.append("}") dotFile.toString() } + + /** + * All the SparkPlanGraphNodes, including those inside of WholeStageCodegen. 
+ */ + val allNodes: Seq[SparkPlanGraphNode] = { + nodes.flatMap { + case cluster: SparkPlanGraphCluster => cluster.nodes :+ cluster + case node => Seq(node) + } + } } private[sql] object SparkPlanGraph { @@ -52,7 +62,7 @@ private[sql] object SparkPlanGraph { val nodeIdGenerator = new AtomicLong(0) val nodes = mutable.ArrayBuffer[SparkPlanGraphNode]() val edges = mutable.ArrayBuffer[SparkPlanGraphEdge]() - buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges) + buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, null, null) new SparkPlanGraph(nodes, edges) } @@ -60,22 +70,40 @@ private[sql] object SparkPlanGraph { planInfo: SparkPlanInfo, nodeIdGenerator: AtomicLong, nodes: mutable.ArrayBuffer[SparkPlanGraphNode], - edges: mutable.ArrayBuffer[SparkPlanGraphEdge]): SparkPlanGraphNode = { - val metrics = planInfo.metrics.map { metric => - SQLPlanMetric(metric.name, metric.accumulatorId, - SQLMetrics.getMetricParam(metric.metricParam)) + edges: mutable.ArrayBuffer[SparkPlanGraphEdge], + parent: SparkPlanGraphNode, + subgraph: SparkPlanGraphCluster): Unit = { + if (planInfo.nodeName == classOf[WholeStageCodegen].getSimpleName) { + val cluster = new SparkPlanGraphCluster( + nodeIdGenerator.getAndIncrement(), + planInfo.nodeName, + planInfo.simpleString, + mutable.ArrayBuffer[SparkPlanGraphNode]()) + nodes += cluster + buildSparkPlanGraphNode( + planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster) + } else if (planInfo.nodeName == classOf[InputAdapter].getSimpleName) { + buildSparkPlanGraphNode(planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null) + } else { + val metrics = planInfo.metrics.map { metric => + SQLPlanMetric(metric.name, metric.accumulatorId, + SQLMetrics.getMetricParam(metric.metricParam)) + } + val node = new SparkPlanGraphNode( + nodeIdGenerator.getAndIncrement(), planInfo.nodeName, + planInfo.simpleString, planInfo.metadata, metrics) + if (subgraph == null) { + nodes += node + } else { + subgraph.nodes += node + } + + if (parent != null) { + edges += SparkPlanGraphEdge(node.id, parent.id) + } + planInfo.children.foreach( + buildSparkPlanGraphNode(_, nodeIdGenerator, nodes, edges, node, subgraph)) } - val node = SparkPlanGraphNode( - nodeIdGenerator.getAndIncrement(), planInfo.nodeName, - planInfo.simpleString, planInfo.metadata, metrics) - - nodes += node - val childrenNodes = planInfo.children.map( - child => buildSparkPlanGraphNode(child, nodeIdGenerator, nodes, edges)) - for (child <- childrenNodes) { - edges += SparkPlanGraphEdge(child.id, node.id) - } - node } } @@ -86,12 +114,12 @@ private[sql] object SparkPlanGraph { * @param name the name of this SparkPlan node * @param metrics metrics that this SparkPlan node will track */ -private[ui] case class SparkPlanGraphNode( - id: Long, - name: String, - desc: String, - metadata: Map[String, String], - metrics: Seq[SQLPlanMetric]) { +private[ui] class SparkPlanGraphNode( + val id: Long, + val name: String, + val desc: String, + val metadata: Map[String, String], + val metrics: Seq[SQLPlanMetric]) { def makeDotNode(metricsValue: Map[Long, String]): String = { val builder = new mutable.StringBuilder(name) @@ -117,6 +145,27 @@ private[ui] case class SparkPlanGraphNode( } } +/** + * Represent a tree of SparkPlan for WholeStageCodegen. 
+ */ +private[ui] class SparkPlanGraphCluster( + id: Long, + name: String, + desc: String, + val nodes: mutable.ArrayBuffer[SparkPlanGraphNode]) + extends SparkPlanGraphNode(id, name, desc, Map.empty, Nil) { + + override def makeDotNode(metricsValue: Map[Long, String]): String = { + s""" + | subgraph cluster${id} { + | label=${name}; + | ${nodes.map(_.makeDotNode(metricsValue)).mkString(" \n")} + | } + """.stripMargin + } +} + + /** * Represent an edge in the SparkPlan tree. `fromId` is the parent node id, and `toId` is the child * node id. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 51285431a47e..cbae19ebd269 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -86,7 +86,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { // If we can track all jobs, check the metric values val metricValues = sqlContext.listener.getExecutionMetrics(executionId) val actualMetrics = SparkPlanGraph(SparkPlanInfo.fromSparkPlan( - df.queryExecution.executedPlan)).nodes.filter { node => + df.queryExecution.executedPlan)).allNodes.filter { node => expectedMetrics.contains(node.id) }.map { node => val nodeMetrics = node.metrics.map { metric => @@ -134,6 +134,14 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { ) } + test("WholeStageCodegen metrics") { + // Assume the execution plan is + // WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Filter(nodeId = 1)) + // TODO: update metrics in generated operators + val df = sqlContext.range(10).filter('id < 5) + testSparkPlanMetrics(df, 1, Map.empty) + } + test("TungstenAggregate metrics") { // Assume the execution plan is // ... -> TungstenAggregate(nodeId = 2) -> Exchange(nodeId = 1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index eef3c1f3e34d..81a159d542c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -83,7 +83,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val df = createTestDataFrame val accumulatorIds = SparkPlanGraph(SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan)) - .nodes.flatMap(_.metrics.map(_.accumulatorId)) + .allNodes.flatMap(_.metrics.map(_.accumulatorId)) // Assume all accumulators are long var accumulatorValue = 0L val accumulatorUpdates = accumulatorIds.map { id => From 00026fa9912ecee5637f1e7dd222f977f31f6766 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 25 Jan 2016 12:59:11 -0800 Subject: [PATCH 005/131] [SPARK-12901][SQL][HOT-FIX] Fix scala 2.11 compilation. 
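The fix is the one-line signature change repeated in the two diffs below; in isolation the pattern is roughly the following (the class name is illustrative, and the stated reason is an assumption — the commit message does not explain the 2.11 failure):

```
// Before (broke the Scala 2.11 build):
//   class ExampleOptions(@transient parameters: Map[String, String]) extends Serializable
// After: the constructor parameter is an explicit private val, so @transient is
// attached to a real field (assumed explanation).
class ExampleOptions(@transient private val parameters: Map[String, String])
  extends Serializable
```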
--- .../apache/spark/sql/execution/datasources/csv/CSVOptions.scala | 2 +- .../spark/sql/execution/datasources/json/JSONOptions.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 5d0e99d7601d..709daccbbef5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -23,7 +23,7 @@ import org.apache.spark.Logging import org.apache.spark.sql.execution.datasources.CompressionCodecs private[sql] class CSVOptions( - @transient parameters: Map[String, String]) + @transient private val parameters: Map[String, String]) extends Logging with Serializable { private def getChar(paramName: String, default: Char): Char = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala index 0a083b5e3598..fe5b20697e40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.datasources.CompressionCodecs * Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]]. */ private[sql] class JSONOptions( - @transient parameters: Map[String, String]) + @transient private val parameters: Map[String, String]) extends Serializable { val samplingRatio = From 9348431da212ec3ab7be2b8e89a952a48b4e2a31 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 25 Jan 2016 13:38:09 -0800 Subject: [PATCH 006/131] [SPARK-12975][SQL] Throwing Exception when Bucketing Columns are part of Partitioning Columns When users are using `partitionBy` and `bucketBy` at the same time, some bucketing columns might be part of partitioning columns. For example, ``` df.write .format(source) .partitionBy("i") .bucketBy(8, "i", "k") .saveAsTable("bucketed_table") ``` However, in the above case, adding column `i` into `bucketBy` is useless. It is just wasting extra CPU when reading or writing bucket tables. Thus, like Hive, we can issue an exception and let users do the change. Also added a test case for checking if the information of `sortBy` and `bucketBy` columns are correctly saved in the metastore table. Could you check if my understanding is correct? cloud-fan rxin marmbrus Thanks! Author: gatorsmile Closes #10891 from gatorsmile/commonKeysInPartitionByBucketBy. 
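With this change the overlapping write shown above fails with an `AnalysisException`; the user-side fix is simply to drop the partition column from `bucketBy`. A sketch, reusing the column names and the `df`/`source` placeholders from the example above:

```
// Rejected after this patch: "i" appears in both partitionBy and bucketBy.
// df.write.format(source).partitionBy("i").bucketBy(8, "i", "k").saveAsTable("bucketed_table")

// Accepted: bucket only by the non-partition column(s).
df.write
  .format(source)
  .partitionBy("i")
  .bucketBy(8, "k")
  .saveAsTable("bucketed_table")
```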
--- .../apache/spark/sql/DataFrameWriter.scala | 9 +++ .../sql/hive/MetastoreDataSourcesSuite.scala | 55 ++++++++++++++++++- .../sql/sources/BucketedWriteSuite.scala | 22 +++++++- 3 files changed, 83 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index ab63fe4aa88b..12eb2393634a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -240,6 +240,15 @@ final class DataFrameWriter private[sql](df: DataFrame) { n <- numBuckets } yield { require(n > 0 && n < 100000, "Bucket number must be greater than 0 and less than 100000.") + + // partitionBy columns cannot be used in bucketBy + if (normalizedParCols.nonEmpty && + normalizedBucketColNames.get.toSet.intersect(normalizedParCols.get.toSet).nonEmpty) { + throw new AnalysisException( + s"bucketBy columns '${bucketColumnNames.get.mkString(", ")}' should not be part of " + + s"partitionBy columns '${partitioningColumns.get.mkString(", ")}'") + } + BucketSpec(n, normalizedBucketColNames.get, normalizedSortColNames.getOrElse(Nil)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 253f13c59852..211932fea00e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -745,7 +745,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } } - test("Saving partition columns information") { + test("Saving partitionBy columns information") { val df = (1 to 10).map(i => (i, i + 1, s"str$i", s"str${i + 1}")).toDF("a", "b", "c", "d") val tableName = s"partitionInfo_${System.currentTimeMillis()}" @@ -776,6 +776,59 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } } + test("Saving information for sortBy and bucketBy columns") { + val df = (1 to 10).map(i => (i, i + 1, s"str$i", s"str${i + 1}")).toDF("a", "b", "c", "d") + val tableName = s"bucketingInfo_${System.currentTimeMillis()}" + + withTable(tableName) { + df.write + .format("parquet") + .bucketBy(8, "d", "b") + .sortBy("c") + .saveAsTable(tableName) + invalidateTable(tableName) + val metastoreTable = catalog.client.getTable("default", tableName) + val expectedBucketByColumns = StructType(df.schema("d") :: df.schema("b") :: Nil) + val expectedSortByColumns = StructType(df.schema("c") :: Nil) + + val numBuckets = metastoreTable.properties("spark.sql.sources.schema.numBuckets").toInt + assert(numBuckets == 8) + + val numBucketCols = metastoreTable.properties("spark.sql.sources.schema.numBucketCols").toInt + assert(numBucketCols == 2) + + val numSortCols = metastoreTable.properties("spark.sql.sources.schema.numSortCols").toInt + assert(numSortCols == 1) + + val actualBucketByColumns = + StructType( + (0 until numBucketCols).map { index => + df.schema(metastoreTable.properties(s"spark.sql.sources.schema.bucketCol.$index")) + }) + // Make sure bucketBy columns are correctly stored in metastore. 
+ assert( + expectedBucketByColumns.sameType(actualBucketByColumns), + s"Partitions columns stored in metastore $actualBucketByColumns is not the " + + s"partition columns defined by the saveAsTable operation $expectedBucketByColumns.") + + val actualSortByColumns = + StructType( + (0 until numSortCols).map { index => + df.schema(metastoreTable.properties(s"spark.sql.sources.schema.sortCol.$index")) + }) + // Make sure sortBy columns are correctly stored in metastore. + assert( + expectedSortByColumns.sameType(actualSortByColumns), + s"Partitions columns stored in metastore $actualSortByColumns is not the " + + s"partition columns defined by the saveAsTable operation $expectedSortByColumns.") + + // Check the content of the saved table. + checkAnswer( + table(tableName).select("c", "b", "d", "a"), + df.select("c", "b", "d", "a")) + } + } + test("insert into a table") { def createDF(from: Int, to: Int): DataFrame = { (from to to).map(i => i -> s"str$i").toDF("c1", "c2") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index 59b74d2b4c5e..a32f8fb4c5a1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -92,10 +92,13 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle fail(s"Unable to find the related bucket files.") } + // Remove the duplicate columns in bucketCols and sortCols; + // Otherwise, we got analysis errors due to duplicate names + val selectedColumns = (bucketCols ++ sortCols).distinct // We may lose the type information after write(e.g. json format doesn't keep schema // information), here we get the types from the original dataframe. - val types = df.select((bucketCols ++ sortCols).map(col): _*).schema.map(_.dataType) - val columns = (bucketCols ++ sortCols).zip(types).map { + val types = df.select(selectedColumns.map(col): _*).schema.map(_.dataType) + val columns = selectedColumns.zip(types).map { case (colName, dt) => col(colName).cast(dt) } @@ -158,6 +161,21 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle } } + test("write bucketed data with the overlapping bucketBy and partitionBy columns") { + intercept[AnalysisException](df.write + .partitionBy("i", "j") + .bucketBy(8, "j", "k") + .sortBy("k") + .saveAsTable("bucketed_table")) + } + + test("write bucketed data with the identical bucketBy and partitionBy columns") { + intercept[AnalysisException](df.write + .partitionBy("i") + .bucketBy(8, "i") + .saveAsTable("bucketed_table")) + } + test("write bucketed data without partitionBy") { for (source <- Seq("parquet", "json", "orc")) { withTable("bucketed_table") { From dcae355c64d7f6fdf61df2feefe464eb96c4cf5e Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 25 Jan 2016 13:54:21 -0800 Subject: [PATCH 007/131] [SPARK-12905][ML][PYSPARK] PCAModel return eigenvalues for PySpark ```PCAModel``` can output ```explainedVariance``` at Python side. cc mengxr srowen Author: Yanbo Liang Closes #10830 from yanboliang/spark-12905. 
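The new Python property surfaces the `explainedVariance` member that the Scala `PCAModel` already carries (see the `@param` doc added below). A rough Scala equivalent of the Python doctest, with illustrative input data:

```
import org.apache.spark.ml.feature.PCA
import org.apache.spark.mllib.linalg.Vectors

// Illustrative 5-dimensional input vectors (any small dataset would do).
val data = Seq(
  Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))),
  Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
  Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
).map(Tuple1.apply)
val df = sqlContext.createDataFrame(data).toDF("features")

val model = new PCA()
  .setInputCol("features")
  .setOutputCol("pca_features")
  .setK(2)
  .fit(df)

// A vector of proportions of variance explained by each principal component.
model.explainedVariance
```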
--- .../main/scala/org/apache/spark/ml/feature/PCA.scala | 2 ++ python/pyspark/ml/feature.py | 11 +++++++++++ 2 files changed, 13 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 7020397f3b06..0e07dfabfeaa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -102,6 +102,8 @@ object PCA extends DefaultParamsReadable[PCA] { * Model fitted by [[PCA]]. * * @param pc A principal components Matrix. Each column is one principal component. + * @param explainedVariance A vector of proportions of variance explained by + * each principal component. */ @Experimental class PCAModel private[ml] ( diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 141ec3492aa9..1fa0eab384e7 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1987,6 +1987,8 @@ class PCA(JavaEstimator, HasInputCol, HasOutputCol): >>> model = pca.fit(df) >>> model.transform(df).collect()[0].pca_features DenseVector([1.648..., -4.013...]) + >>> model.explainedVariance + DenseVector([0.794..., 0.205...]) .. versionadded:: 1.5.0 """ @@ -2052,6 +2054,15 @@ def pc(self): """ return self._call_java("pc") + @property + @since("2.0.0") + def explainedVariance(self): + """ + Returns a vector of proportions of variance + explained by each principal component. + """ + return self._call_java("explainedVariance") + @inherit_doc class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol): From 6f0f1d9e04a8db47e2f6f8fcfe9dea9de0f633da Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 25 Jan 2016 15:05:05 -0800 Subject: [PATCH 008/131] [SPARK-12934][SQL] Count-min sketch serialization This PR adds serialization support for `CountMinSketch`. A version number is added to version the serialized binary format. Author: Cheng Lian Closes #10893 from liancheng/cms-serialization. --- .../spark/util/sketch/CountMinSketch.java | 32 ++++- .../spark/util/sketch/CountMinSketchImpl.java | 129 ++++++++++++++++-- .../sketch/IncompatibleMergeException.java | 24 ++++ .../util/sketch/CountMinSketchSuite.scala | 47 ++++++- 4 files changed, 213 insertions(+), 19 deletions(-) create mode 100644 common/sketch/src/main/java/org/apache/spark/util/sketch/IncompatibleMergeException.java diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index 21b161bc74ae..67938644d9f6 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -17,6 +17,7 @@ package org.apache.spark.util.sketch; +import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; @@ -54,6 +55,25 @@ * This implementation is largely based on the {@code CountMinSketch} class from stream-lib. */ abstract public class CountMinSketch { + /** + * Version number of the serialized binary format. + */ + public enum Version { + V1(1); + + private final int versionNumber; + + Version(int versionNumber) { + this.versionNumber = versionNumber; + } + + public int getVersionNumber() { + return versionNumber; + } + } + + public abstract Version version(); + /** * Returns the relative error (or {@code eps}) of this {@link CountMinSketch}. 
*/ @@ -99,19 +119,23 @@ abstract public class CountMinSketch { * * Note that only Count-Min sketches with the same {@code depth}, {@code width}, and random seed * can be merged. + * + * @exception IncompatibleMergeException if the {@code other} {@link CountMinSketch} has + * incompatible depth, width, relative-error, confidence, or random seed. */ - public abstract CountMinSketch mergeInPlace(CountMinSketch other); + public abstract CountMinSketch mergeInPlace(CountMinSketch other) + throws IncompatibleMergeException; /** * Writes out this {@link CountMinSketch} to an output stream in binary format. */ - public abstract void writeTo(OutputStream out); + public abstract void writeTo(OutputStream out) throws IOException; /** * Reads in a {@link CountMinSketch} from an input stream. */ - public static CountMinSketch readFrom(InputStream in) { - throw new UnsupportedOperationException("Not implemented yet"); + public static CountMinSketch readFrom(InputStream in) throws IOException { + return CountMinSketchImpl.readFrom(in); } /** diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index e9fdbe3a8686..0209446ea3b1 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -17,11 +17,30 @@ package org.apache.spark.util.sketch; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStream; import java.io.OutputStream; import java.io.UnsupportedEncodingException; import java.util.Arrays; import java.util.Random; +/* + * Binary format of a serialized CountMinSketchImpl, version 1 (all values written in big-endian + * order): + * + * - Version number, always 1 (32 bit) + * - Total count of added items (64 bit) + * - Depth (32 bit) + * - Width (32 bit) + * - Hash functions (depth * 64 bit) + * - Count table + * - Row 0 (width * 64 bit) + * - Row 1 (width * 64 bit) + * - ... 
+ * - Row depth - 1 (width * 64 bit) + */ class CountMinSketchImpl extends CountMinSketch { public static final long PRIME_MODULUS = (1L << 31) - 1; @@ -33,7 +52,7 @@ class CountMinSketchImpl extends CountMinSketch { private double eps; private double confidence; - public CountMinSketchImpl(int depth, int width, int seed) { + CountMinSketchImpl(int depth, int width, int seed) { this.depth = depth; this.width = width; this.eps = 2.0 / width; @@ -41,7 +60,7 @@ public CountMinSketchImpl(int depth, int width, int seed) { initTablesWith(depth, width, seed); } - public CountMinSketchImpl(double eps, double confidence, int seed) { + CountMinSketchImpl(double eps, double confidence, int seed) { // 2/w = eps ; w = 2/eps // 1/2^depth <= 1-confidence ; depth >= -log2 (1-confidence) this.eps = eps; @@ -51,6 +70,53 @@ public CountMinSketchImpl(double eps, double confidence, int seed) { initTablesWith(depth, width, seed); } + CountMinSketchImpl(int depth, int width, long totalCount, long hashA[], long table[][]) { + this.depth = depth; + this.width = width; + this.eps = 2.0 / width; + this.confidence = 1 - 1 / Math.pow(2, depth); + this.hashA = hashA; + this.table = table; + this.totalCount = totalCount; + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } + + if (other == null || !(other instanceof CountMinSketchImpl)) { + return false; + } + + CountMinSketchImpl that = (CountMinSketchImpl) other; + + return + this.depth == that.depth && + this.width == that.width && + this.totalCount == that.totalCount && + Arrays.equals(this.hashA, that.hashA) && + Arrays.deepEquals(this.table, that.table); + } + + @Override + public int hashCode() { + int hash = depth; + + hash = hash * 31 + width; + hash = hash * 31 + (int) (totalCount ^ (totalCount >>> 32)); + hash = hash * 31 + Arrays.hashCode(hashA); + hash = hash * 31 + Arrays.deepHashCode(table); + + return hash; + } + + @Override + public Version version() { + return Version.V1; + } + private void initTablesWith(int depth, int width, int seed) { this.table = new long[depth][width]; this.hashA = new long[depth]; @@ -221,27 +287,29 @@ private long estimateCountForStringItem(String item) { } @Override - public CountMinSketch mergeInPlace(CountMinSketch other) { + public CountMinSketch mergeInPlace(CountMinSketch other) throws IncompatibleMergeException { if (other == null) { - throw new CMSMergeException("Cannot merge null estimator"); + throw new IncompatibleMergeException("Cannot merge null estimator"); } if (!(other instanceof CountMinSketchImpl)) { - throw new CMSMergeException("Cannot merge estimator of class " + other.getClass().getName()); + throw new IncompatibleMergeException( + "Cannot merge estimator of class " + other.getClass().getName() + ); } CountMinSketchImpl that = (CountMinSketchImpl) other; if (this.depth != that.depth) { - throw new CMSMergeException("Cannot merge estimators of different depth"); + throw new IncompatibleMergeException("Cannot merge estimators of different depth"); } if (this.width != that.width) { - throw new CMSMergeException("Cannot merge estimators of different width"); + throw new IncompatibleMergeException("Cannot merge estimators of different width"); } if (!Arrays.equals(this.hashA, that.hashA)) { - throw new CMSMergeException("Cannot merge estimators of different seed"); + throw new IncompatibleMergeException("Cannot merge estimators of different seed"); } for (int i = 0; i < this.table.length; ++i) { @@ -256,13 +324,48 @@ public CountMinSketch 
mergeInPlace(CountMinSketch other) { } @Override - public void writeTo(OutputStream out) { - throw new UnsupportedOperationException("Not implemented yet"); + public void writeTo(OutputStream out) throws IOException { + DataOutputStream dos = new DataOutputStream(out); + + dos.writeInt(version().getVersionNumber()); + + dos.writeLong(this.totalCount); + dos.writeInt(this.depth); + dos.writeInt(this.width); + + for (int i = 0; i < this.depth; ++i) { + dos.writeLong(this.hashA[i]); + } + + for (int i = 0; i < this.depth; ++i) { + for (int j = 0; j < this.width; ++j) { + dos.writeLong(table[i][j]); + } + } } - protected static class CMSMergeException extends RuntimeException { - public CMSMergeException(String message) { - super(message); + public static CountMinSketchImpl readFrom(InputStream in) throws IOException { + DataInputStream dis = new DataInputStream(in); + + // Ignores version number + dis.readInt(); + + long totalCount = dis.readLong(); + int depth = dis.readInt(); + int width = dis.readInt(); + + long hashA[] = new long[depth]; + for (int i = 0; i < depth; ++i) { + hashA[i] = dis.readLong(); + } + + long table[][] = new long[depth][width]; + for (int i = 0; i < depth; ++i) { + for (int j = 0; j < width; ++j) { + table[i][j] = dis.readLong(); + } } + + return new CountMinSketchImpl(depth, width, totalCount, hashA, table); } } diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/IncompatibleMergeException.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/IncompatibleMergeException.java new file mode 100644 index 000000000000..64b567caa57c --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/IncompatibleMergeException.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.sketch; + +public class IncompatibleMergeException extends Exception { + public IncompatibleMergeException(String message) { + super(message); + } +} diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala index ec5b4eddeca0..b9c7f5c23a8f 100644 --- a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala +++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.util.sketch +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} + import scala.reflect.ClassTag import scala.util.Random @@ -29,9 +31,22 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite private val seed = 42 + // Serializes and deserializes a given `CountMinSketch`, then checks whether the deserialized + // version is equivalent to the original one. + private def checkSerDe(sketch: CountMinSketch): Unit = { + val out = new ByteArrayOutputStream() + sketch.writeTo(out) + + val in = new ByteArrayInputStream(out.toByteArray) + val deserialized = CountMinSketch.readFrom(in) + + assert(sketch === deserialized) + } + def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = { test(s"accuracy - $typeName") { - val r = new Random() + // Uses fixed seed to ensure reproducible test execution + val r = new Random(31) val numAllItems = 1000000 val allItems = Array.fill(numAllItems)(itemGenerator(r)) @@ -45,7 +60,10 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite } val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) + checkSerDe(sketch) + sampledItemIndices.foreach(i => sketch.add(allItems(i))) + checkSerDe(sketch) val probCorrect = { val numErrors = allItems.map { item => @@ -66,7 +84,9 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite def testMergeInPlace[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = { test(s"mergeInPlace - $typeName") { - val r = new Random() + // Uses fixed seed to ensure reproducible test execution + val r = new Random(31) + val numToMerge = 5 val numItemsPerSketch = 100000 val perSketchItems = Array.fill(numToMerge, numItemsPerSketch) { @@ -75,11 +95,16 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite val sketches = perSketchItems.map { items => val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) + checkSerDe(sketch) + items.foreach(sketch.add) + checkSerDe(sketch) + sketch } val mergedSketch = sketches.reduce(_ mergeInPlace _) + checkSerDe(mergedSketch) val expectedSketch = { val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) @@ -109,4 +134,22 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite testItemType[Long]("Long") { _.nextLong() } testItemType[String]("String") { r => r.nextString(r.nextInt(20)) } + + test("incompatible merge") { + intercept[IncompatibleMergeException] { + CountMinSketch.create(10, 10, 1).mergeInPlace(null) + } + + intercept[IncompatibleMergeException] { + val sketch1 = CountMinSketch.create(10, 20, 1) + val sketch2 = CountMinSketch.create(10, 20, 2) + sketch1.mergeInPlace(sketch2) + } + + intercept[IncompatibleMergeException] { + val sketch1 = CountMinSketch.create(10, 10, 1) + val sketch2 = CountMinSketch.create(10, 20, 2) + sketch1.mergeInPlace(sketch2) + } + } } From 
be375fcbd200fb0e210b8edcfceb5a1bcdbba94b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 25 Jan 2016 16:23:59 -0800 Subject: [PATCH 009/131] [SPARK-12879] [SQL] improve the unsafe row writing framework As we begin to use unsafe row writing framework(`BufferHolder` and `UnsafeRowWriter`) in more and more places(`UnsafeProjection`, `UnsafeRowParquetRecordReader`, `GenerateColumnAccessor`, etc.), we should add more doc to it and make it easier to use. This PR abstract the technique used in `UnsafeRowParquetRecordReader`: avoid unnecessary operatition as more as possible. For example, do not always point the row to the buffer at the end, we only need to update the size of row. If all fields are of primitive type, we can even save the row size updating. Then we can apply this technique to more places easily. a local benchmark shows `UnsafeProjection` is up to 1.7x faster after this PR: **old version** ``` Intel(R) Core(TM) i7-4960HQ CPU 2.60GHz unsafe projection: Avg Time(ms) Avg Rate(M/s) Relative Rate ------------------------------------------------------------------------------- single long 2616.04 102.61 1.00 X single nullable long 3032.54 88.52 0.86 X primitive types 9121.05 29.43 0.29 X nullable primitive types 12410.60 21.63 0.21 X ``` **new version** ``` Intel(R) Core(TM) i7-4960HQ CPU 2.60GHz unsafe projection: Avg Time(ms) Avg Rate(M/s) Relative Rate ------------------------------------------------------------------------------- single long 1533.34 175.07 1.00 X single nullable long 2306.73 116.37 0.66 X primitive types 8403.93 31.94 0.18 X nullable primitive types 12448.39 21.56 0.12 X ``` For single non-nullable long(the best case), we can have about 1.7x speed up. Even it's nullable, we can still have 1.3x speed up. For other cases, it's not such a boost as the saved operations only take a little proportion of the whole process. The benchmark code is included in this PR. Author: Wenchen Fan Closes #10809 from cloud-fan/unsafe-projection. --- .../expressions/codegen/BufferHolder.java | 44 +++--- .../expressions/codegen/UnsafeRowWriter.java | 58 +++++--- .../codegen/GenerateUnsafeProjection.scala | 66 ++++++--- .../spark/sql/UnsafeProjectionBenchmark.scala | 136 ++++++++++++++++++ .../parquet/UnsafeRowParquetRecordReader.java | 17 +-- .../columnar/GenerateColumnAccessor.scala | 8 +- .../datasources/text/DefaultSource.scala | 7 +- 7 files changed, 258 insertions(+), 78 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index d26b1b187c27..af61e2011f40 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -21,24 +21,40 @@ import org.apache.spark.unsafe.Platform; /** - * A helper class to manage the row buffer when construct unsafe rows. + * A helper class to manage the data buffer for an unsafe row. The data buffer can grow and + * automatically re-point the unsafe row to it. + * + * This class can be used to build a one-pass unsafe row writing program, i.e. data will be written + * to the data buffer directly and no extra copy is needed. There should be only one instance of + * this class per writing program, so that the memory segment/data buffer can be reused. 
Note that + * for each incoming record, we should call `reset` of BufferHolder instance before write the record + * and reuse the data buffer. + * + * Generally we should call `UnsafeRow.setTotalSize` and pass in `BufferHolder.totalSize` to update + * the size of the result row, after writing a record to the buffer. However, we can skip this step + * if the fields of row are all fixed-length, as the size of result row is also fixed. */ public class BufferHolder { public byte[] buffer; public int cursor = Platform.BYTE_ARRAY_OFFSET; + private final UnsafeRow row; + private final int fixedSize; - public BufferHolder() { - this(64); + public BufferHolder(UnsafeRow row) { + this(row, 64); } - public BufferHolder(int size) { - buffer = new byte[size]; + public BufferHolder(UnsafeRow row, int initialSize) { + this.fixedSize = UnsafeRow.calculateBitSetWidthInBytes(row.numFields()) + 8 * row.numFields(); + this.buffer = new byte[fixedSize + initialSize]; + this.row = row; + this.row.pointTo(buffer, buffer.length); } /** - * Grows the buffer to at least neededSize. If row is non-null, points the row to the buffer. + * Grows the buffer by at least neededSize and points the row to the buffer. */ - public void grow(int neededSize, UnsafeRow row) { + public void grow(int neededSize) { final int length = totalSize() + neededSize; if (buffer.length < length) { // This will not happen frequently, because the buffer is re-used. @@ -50,22 +66,12 @@ public void grow(int neededSize, UnsafeRow row) { Platform.BYTE_ARRAY_OFFSET, totalSize()); buffer = tmp; - if (row != null) { - row.pointTo(buffer, length * 2); - } + row.pointTo(buffer, buffer.length); } } - public void grow(int neededSize) { - grow(neededSize, null); - } - public void reset() { - cursor = Platform.BYTE_ARRAY_OFFSET; - } - public void resetTo(int offset) { - assert(offset <= buffer.length); - cursor = Platform.BYTE_ARRAY_OFFSET + offset; + cursor = Platform.BYTE_ARRAY_OFFSET + fixedSize; } public int totalSize() { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index e227c0dec974..477661704387 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -26,38 +26,56 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * A helper class to write data into global row buffer using `UnsafeRow` format, - * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}. + * A helper class to write data into global row buffer using `UnsafeRow` format. + * + * It will remember the offset of row buffer which it starts to write, and move the cursor of row + * buffer while writing. If new data(can be the input record if this is the outermost writer, or + * nested struct if this is an inner writer) comes, the starting cursor of row buffer may be + * changed, so we need to call `UnsafeRowWriter.reset` before writing, to update the + * `startingOffset` and clear out null bits. + * + * Note that if this is the outermost writer, which means we will always write from the very + * beginning of the global row buffer, we don't need to update `startingOffset` and can just call + * `zeroOutNullBytes` before writing new data. 
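+ *
+ * A typical top-level write sequence in this patch (see `GenerateColumnAccessor` and the text
+ * data source below) is therefore: `holder.reset()`, `rowWriter.zeroOutNullBytes()`, write all
+ * fields, then `row.setTotalSize(holder.totalSize())` unless every field is fixed-length.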
*/ public class UnsafeRowWriter { - private BufferHolder holder; + private final BufferHolder holder; // The offset of the global buffer where we start to write this row. private int startingOffset; - private int nullBitsSize; - private UnsafeRow row; + private final int nullBitsSize; + private final int fixedSize; - public void initialize(BufferHolder holder, int numFields) { + public UnsafeRowWriter(BufferHolder holder, int numFields) { this.holder = holder; - this.startingOffset = holder.cursor; this.nullBitsSize = UnsafeRow.calculateBitSetWidthInBytes(numFields); + this.fixedSize = nullBitsSize + 8 * numFields; + this.startingOffset = holder.cursor; + } + + /** + * Resets the `startingOffset` according to the current cursor of row buffer, and clear out null + * bits. This should be called before we write a new nested struct to the row buffer. + */ + public void reset() { + this.startingOffset = holder.cursor; // grow the global buffer to make sure it has enough space to write fixed-length data. - final int fixedSize = nullBitsSize + 8 * numFields; - holder.grow(fixedSize, row); + holder.grow(fixedSize); holder.cursor += fixedSize; - // zero-out the null bits region + zeroOutNullBytes(); + } + + /** + * Clears out null bits. This should be called before we write a new row to row buffer. + */ + public void zeroOutNullBytes() { for (int i = 0; i < nullBitsSize; i += 8) { Platform.putLong(holder.buffer, startingOffset + i, 0L); } } - public void initialize(UnsafeRow row, BufferHolder holder, int numFields) { - initialize(holder, numFields); - this.row = row; - } - private void zeroOutPaddingBytes(int numBytes) { if ((numBytes & 0x07) > 0) { Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L); @@ -98,7 +116,7 @@ public void alignToWords(int numBytes) { if (remainder > 0) { final int paddingBytes = 8 - remainder; - holder.grow(paddingBytes, row); + holder.grow(paddingBytes); for (int i = 0; i < paddingBytes; i++) { Platform.putByte(holder.buffer, holder.cursor, (byte) 0); @@ -161,7 +179,7 @@ public void write(int ordinal, Decimal input, int precision, int scale) { } } else { // grow the global buffer before writing data. - holder.grow(16, row); + holder.grow(16); // zero-out the bytes Platform.putLong(holder.buffer, holder.cursor, 0L); @@ -193,7 +211,7 @@ public void write(int ordinal, UTF8String input) { final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); // grow the global buffer before writing data. - holder.grow(roundedSize, row); + holder.grow(roundedSize); zeroOutPaddingBytes(numBytes); @@ -214,7 +232,7 @@ public void write(int ordinal, byte[] input, int offset, int numBytes) { final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); // grow the global buffer before writing data. - holder.grow(roundedSize, row); + holder.grow(roundedSize); zeroOutPaddingBytes(numBytes); @@ -230,7 +248,7 @@ public void write(int ordinal, byte[] input, int offset, int numBytes) { public void write(int ordinal, CalendarInterval input) { // grow the global buffer before writing data. - holder.grow(16, row); + holder.grow(16); // Write the months and microseconds fields of Interval to the variable length portion. 
Platform.putLong(holder.buffer, holder.cursor, input.months); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 72bf39a0398b..6aa9cbf08bdb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -43,9 +43,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => false } - private val rowWriterClass = classOf[UnsafeRowWriter].getName - private val arrayWriterClass = classOf[UnsafeArrayWriter].getName - // TODO: if the nullability of field is correct, we can use it to save null check. private def writeStructToBuffer( ctx: CodegenContext, @@ -73,9 +70,27 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro row: String, inputs: Seq[ExprCode], inputTypes: Seq[DataType], - bufferHolder: String): String = { + bufferHolder: String, + isTopLevel: Boolean = false): String = { + val rowWriterClass = classOf[UnsafeRowWriter].getName val rowWriter = ctx.freshName("rowWriter") - ctx.addMutableState(rowWriterClass, rowWriter, s"this.$rowWriter = new $rowWriterClass();") + ctx.addMutableState(rowWriterClass, rowWriter, + s"this.$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});") + + val resetWriter = if (isTopLevel) { + // For top level row writer, it always writes to the beginning of the global buffer holder, + // which means its fixed-size region always in the same position, so we don't need to call + // `reset` to set up its fixed-size region every time. + if (inputs.map(_.isNull).forall(_ == "false")) { + // If all fields are not nullable, which means the null bits never changes, then we don't + // need to clear it out every time. 
+ "" + } else { + s"$rowWriter.zeroOutNullBytes();" + } + } else { + s"$rowWriter.reset();" + } val writeFields = inputs.zip(inputTypes).zipWithIndex.map { case ((input, dataType), index) => @@ -122,11 +137,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor); """ - case _ if ctx.isPrimitiveType(dt) => - s""" - $rowWriter.write($index, ${input.value}); - """ - case t: DecimalType => s"$rowWriter.write($index, ${input.value}, ${t.precision}, ${t.scale});" @@ -153,7 +163,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } s""" - $rowWriter.initialize($bufferHolder, ${inputs.length}); + $resetWriter ${ctx.splitExpressions(row, writeFields)} """.trim } @@ -164,6 +174,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro input: String, elementType: DataType, bufferHolder: String): String = { + val arrayWriterClass = classOf[UnsafeArrayWriter].getName val arrayWriter = ctx.freshName("arrayWriter") ctx.addMutableState(arrayWriterClass, arrayWriter, s"this.$arrayWriter = new $arrayWriterClass();") @@ -288,22 +299,43 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination) val exprTypes = expressions.map(_.dataType) + val numVarLenFields = exprTypes.count { + case dt if UnsafeRow.isFixedLength(dt) => false + // TODO: consider large decimal and interval type + case _ => true + } + val result = ctx.freshName("result") ctx.addMutableState("UnsafeRow", result, s"$result = new UnsafeRow(${expressions.length});") - val bufferHolder = ctx.freshName("bufferHolder") + + val holder = ctx.freshName("holder") val holderClass = classOf[BufferHolder].getName - ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();") + ctx.addMutableState(holderClass, holder, + s"this.$holder = new $holderClass($result, ${numVarLenFields * 32});") + + val resetBufferHolder = if (numVarLenFields == 0) { + "" + } else { + s"$holder.reset();" + } + val updateRowSize = if (numVarLenFields == 0) { + "" + } else { + s"$result.setTotalSize($holder.totalSize());" + } // Evaluate all the subexpression. val evalSubexpr = ctx.subexprFunctions.mkString("\n") + val writeExpressions = + writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, holder, isTopLevel = true) + val code = s""" - $bufferHolder.reset(); + $resetBufferHolder $evalSubexpr - ${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, bufferHolder)} - - $result.pointTo($bufferHolder.buffer, $bufferHolder.totalSize()); + $writeExpressions + $updateRowSize """ ExprCode(code, "false", result) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala new file mode 100644 index 000000000000..a6d90409382e --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.types._ +import org.apache.spark.util.Benchmark + +/** + * Benchmark [[UnsafeProjection]] for fixed-length/primitive-type fields. + */ +object UnsafeProjectionBenchmark { + + def generateRows(schema: StructType, numRows: Int): Array[InternalRow] = { + val generator = RandomDataGenerator.forType(schema, nullable = false).get + val encoder = RowEncoder(schema) + (1 to numRows).map(_ => encoder.toRow(generator().asInstanceOf[Row]).copy()).toArray + } + + def main(args: Array[String]) { + val iters = 1024 * 16 + val numRows = 1024 * 16 + + val benchmark = new Benchmark("unsafe projection", iters * numRows) + + + val schema1 = new StructType().add("l", LongType, false) + val attrs1 = schema1.toAttributes + val rows1 = generateRows(schema1, numRows) + val projection1 = UnsafeProjection.create(attrs1, attrs1) + + benchmark.addCase("single long") { _ => + for (_ <- 1 to iters) { + var sum = 0L + var i = 0 + while (i < numRows) { + sum += projection1(rows1(i)).getLong(0) + i += 1 + } + } + } + + val schema2 = new StructType().add("l", LongType, true) + val attrs2 = schema2.toAttributes + val rows2 = generateRows(schema2, numRows) + val projection2 = UnsafeProjection.create(attrs2, attrs2) + + benchmark.addCase("single nullable long") { _ => + for (_ <- 1 to iters) { + var sum = 0L + var i = 0 + while (i < numRows) { + sum += projection2(rows2(i)).getLong(0) + i += 1 + } + } + } + + + val schema3 = new StructType() + .add("boolean", BooleanType, false) + .add("byte", ByteType, false) + .add("short", ShortType, false) + .add("int", IntegerType, false) + .add("long", LongType, false) + .add("float", FloatType, false) + .add("double", DoubleType, false) + val attrs3 = schema3.toAttributes + val rows3 = generateRows(schema3, numRows) + val projection3 = UnsafeProjection.create(attrs3, attrs3) + + benchmark.addCase("7 primitive types") { _ => + for (_ <- 1 to iters) { + var sum = 0L + var i = 0 + while (i < numRows) { + sum += projection3(rows3(i)).getLong(0) + i += 1 + } + } + } + + + val schema4 = new StructType() + .add("boolean", BooleanType, true) + .add("byte", ByteType, true) + .add("short", ShortType, true) + .add("int", IntegerType, true) + .add("long", LongType, true) + .add("float", FloatType, true) + .add("double", DoubleType, true) + val attrs4 = schema4.toAttributes + val rows4 = generateRows(schema4, numRows) + val projection4 = UnsafeProjection.create(attrs4, attrs4) + + benchmark.addCase("7 nullable primitive types") { _ => + for (_ <- 1 to iters) { + var sum = 0L + var i = 0 + while (i < numRows) { + sum += projection4(rows4(i)).getLong(0) + i += 1 + } + } + } + + + /* + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + unsafe projection: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + single long 1533.34 175.07 1.00 X + single nullable long 2306.73 116.37 0.66 X + 
primitive types 8403.93 31.94 0.18 X + nullable primitive types 12448.39 21.56 0.12 X + */ + benchmark.run() + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index 80805f15a8f0..17adfec32192 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -73,11 +73,6 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas */ private boolean containsVarLenFields; - /** - * The number of bytes in the fixed length portion of the row. - */ - private int fixedSizeBytes; - /** * For each request column, the reader to read this column. * columnsReaders[i] populated the UnsafeRow's attribute at i. @@ -266,19 +261,13 @@ private void initializeInternal() throws IOException { /** * Initialize rows and rowWriters. These objects are reused across all rows in the relation. */ - int rowByteSize = UnsafeRow.calculateBitSetWidthInBytes(requestedSchema.getFieldCount()); - rowByteSize += 8 * requestedSchema.getFieldCount(); - fixedSizeBytes = rowByteSize; - rowByteSize += numVarLenFields * DEFAULT_VAR_LEN_SIZE; containsVarLenFields = numVarLenFields > 0; rowWriters = new UnsafeRowWriter[rows.length]; for (int i = 0; i < rows.length; ++i) { rows[i] = new UnsafeRow(requestedSchema.getFieldCount()); - rowWriters[i] = new UnsafeRowWriter(); - BufferHolder holder = new BufferHolder(rowByteSize); - rowWriters[i].initialize(rows[i], holder, requestedSchema.getFieldCount()); - rows[i].pointTo(holder.buffer, Platform.BYTE_ARRAY_OFFSET, holder.buffer.length); + BufferHolder holder = new BufferHolder(rows[i], numVarLenFields * DEFAULT_VAR_LEN_SIZE); + rowWriters[i] = new UnsafeRowWriter(holder, requestedSchema.getFieldCount()); } } @@ -295,7 +284,7 @@ private boolean loadBatch() throws IOException { if (containsVarLenFields) { for (int i = 0; i < rowWriters.length; ++i) { - rowWriters[i].holder().resetTo(fixedSizeBytes); + rowWriters[i].holder().reset(); } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 72eb1f6cf051..738b9a35d1c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -132,8 +132,8 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera private ByteOrder nativeOrder = null; private byte[][] buffers = null; private UnsafeRow unsafeRow = new UnsafeRow($numFields); - private BufferHolder bufferHolder = new BufferHolder(); - private UnsafeRowWriter rowWriter = new UnsafeRowWriter(); + private BufferHolder bufferHolder = new BufferHolder(unsafeRow); + private UnsafeRowWriter rowWriter = new UnsafeRowWriter(bufferHolder, $numFields); private MutableUnsafeRow mutableRow = null; private int currentRow = 0; @@ -181,9 +181,9 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera public InternalRow next() { currentRow += 1; bufferHolder.reset(); - rowWriter.initialize(bufferHolder, $numFields); + rowWriter.zeroOutNullBytes(); ${extractors.mkString("\n")} - 
unsafeRow.pointTo(bufferHolder.buffer, bufferHolder.totalSize()); + unsafeRow.setTotalSize(bufferHolder.totalSize()); return unsafeRow; } }""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index bd2d17c0189e..430257f60d9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -98,16 +98,15 @@ private[sql] class TextRelation( sqlContext.sparkContext.hadoopRDD( conf.asInstanceOf[JobConf], classOf[TextInputFormat], classOf[LongWritable], classOf[Text]) .mapPartitions { iter => - val bufferHolder = new BufferHolder - val unsafeRowWriter = new UnsafeRowWriter val unsafeRow = new UnsafeRow(1) + val bufferHolder = new BufferHolder(unsafeRow) + val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1) iter.map { case (_, line) => // Writes to an UnsafeRow directly bufferHolder.reset() - unsafeRowWriter.initialize(bufferHolder, 1) unsafeRowWriter.write(0, line.getBytes, 0, line.getLength) - unsafeRow.pointTo(bufferHolder.buffer, bufferHolder.totalSize()) + unsafeRow.setTotalSize(bufferHolder.totalSize()) unsafeRow } } From 109061f7ad27225669cbe609ec38756b31d4e1b9 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 25 Jan 2016 17:58:11 -0800 Subject: [PATCH 010/131] [SPARK-12936][SQL] Initial bloom filter implementation This PR adds an initial implementation of bloom filter in the newly added sketch module. The implementation is based on the [`BloomFilter` class in guava](https://code.google.com/p/guava-libraries/source/browse/guava/src/com/google/common/hash/BloomFilter.java). Some difference from the design doc: * expose `bitSize` instead of `sizeInBytes` to user. * always need the `expectedInsertions` parameter when create bloom filter. Author: Wenchen Fan Closes #10883 from cloud-fan/bloom-filter. --- .../apache/spark/util/sketch/BitArray.java | 94 ++++++++++ .../apache/spark/util/sketch/BloomFilter.java | 153 ++++++++++++++++ .../spark/util/sketch/BloomFilterImpl.java | 164 ++++++++++++++++++ .../spark/util/sketch/BitArraySuite.scala | 77 ++++++++ .../spark/util/sketch/BloomFilterSuite.scala | 114 ++++++++++++ 5 files changed, 602 insertions(+) create mode 100644 common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java create mode 100644 common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java create mode 100644 common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java create mode 100644 common/sketch/src/test/scala/org/apache/spark/util/sketch/BitArraySuite.scala create mode 100644 common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java new file mode 100644 index 000000000000..1bc665ad54b7 --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +import java.util.Arrays; + +public final class BitArray { + private final long[] data; + private long bitCount; + + static int numWords(long numBits) { + long numWords = (long) Math.ceil(numBits / 64.0); + if (numWords > Integer.MAX_VALUE) { + throw new IllegalArgumentException("Can't allocate enough space for " + numBits + " bits"); + } + return (int) numWords; + } + + BitArray(long numBits) { + if (numBits <= 0) { + throw new IllegalArgumentException("numBits must be positive"); + } + this.data = new long[numWords(numBits)]; + long bitCount = 0; + for (long value : data) { + bitCount += Long.bitCount(value); + } + this.bitCount = bitCount; + } + + /** Returns true if the bit changed value. */ + boolean set(long index) { + if (!get(index)) { + data[(int) (index >>> 6)] |= (1L << index); + bitCount++; + return true; + } + return false; + } + + boolean get(long index) { + return (data[(int) (index >>> 6)] & (1L << index)) != 0; + } + + /** Number of bits */ + long bitSize() { + return (long) data.length * Long.SIZE; + } + + /** Number of set bits (1s) */ + long cardinality() { + return bitCount; + } + + /** Combines the two BitArrays using bitwise OR. */ + void putAll(BitArray array) { + assert data.length == array.data.length : "BitArrays must be of equal length when merging"; + long bitCount = 0; + for (int i = 0; i < data.length; i++) { + data[i] |= array.data[i]; + bitCount += Long.bitCount(data[i]); + } + this.bitCount = bitCount; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || !(o instanceof BitArray)) return false; + + BitArray bitArray = (BitArray) o; + return Arrays.equals(data, bitArray.data); + } + + @Override + public int hashCode() { + return Arrays.hashCode(data); + } +} diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java new file mode 100644 index 000000000000..38949c6311df --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.sketch; + +/** + * A Bloom filter is a space-efficient probabilistic data structure, that is used to test whether + * an element is a member of a set. It returns false when the element is definitely not in the + * set, returns true when the element is probably in the set. + * + * Internally a Bloom filter is initialized with 2 information: how many space to use(number of + * bits) and how many hash values to calculate for each record. To get as lower false positive + * probability as possible, user should call {@link BloomFilter#create} to automatically pick a + * best combination of these 2 parameters. + * + * Currently the following data types are supported: + *
+ * <ul>
+ *   <li>{@link Byte}</li>
+ *   <li>{@link Short}</li>
+ *   <li>{@link Integer}</li>
+ *   <li>{@link Long}</li>
+ *   <li>{@link String}</li>
+ * </ul>
+ * + * The implementation is largely based on the {@code BloomFilter} class from guava. + */ +public abstract class BloomFilter { + /** + * Returns the false positive probability, i.e. the probability that + * {@linkplain #mightContain(Object)} will erroneously return {@code true} for an object that + * has not actually been put in the {@code BloomFilter}. + * + *

Ideally, this number should be close to the {@code fpp} parameter + * passed in to create this bloom filter, or smaller. If it is + * significantly higher, it is usually the case that too many elements (more than + * expected) have been put in the {@code BloomFilter}, degenerating it. + */ + public abstract double expectedFpp(); + + /** + * Returns the number of bits in the underlying bit array. + */ + public abstract long bitSize(); + + /** + * Puts an element into this {@code BloomFilter}. Ensures that subsequent invocations of + * {@link #mightContain(Object)} with the same element will always return {@code true}. + * + * @return true if the bloom filter's bits changed as a result of this operation. If the bits + * changed, this is definitely the first time {@code object} has been added to the + * filter. If the bits haven't changed, this might be the first time {@code object} + * has been added to the filter. Note that {@code put(t)} always returns the + * opposite result to what {@code mightContain(t)} would have returned at the time + * it is called. + */ + public abstract boolean put(Object item); + + /** + * Determines whether a given bloom filter is compatible with this bloom filter. For two + * bloom filters to be compatible, they must have the same bit size. + * + * @param other The bloom filter to check for compatibility. + */ + public abstract boolean isCompatible(BloomFilter other); + + /** + * Combines this bloom filter with another bloom filter by performing a bitwise OR of the + * underlying data. The mutations happen to this instance. Callers must ensure the + * bloom filters are appropriately sized to avoid saturating them. + * + * @param other The bloom filter to combine this bloom filter with. It is not mutated. + * @throws IllegalArgumentException if {@code isCompatible(that) == false} + */ + public abstract BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeException; + + /** + * Returns {@code true} if the element might have been put in this Bloom filter, + * {@code false} if this is definitely not the case. + */ + public abstract boolean mightContain(Object item); + + /** + * Computes the optimal k (number of hashes per element inserted in Bloom filter), given the + * expected insertions and total number of bits in the Bloom filter. + * + * See http://en.wikipedia.org/wiki/File:Bloom_filter_fp_probability.svg for the formula. + * + * @param n expected insertions (must be positive) + * @param m total number of bits in Bloom filter (must be positive) + */ + private static int optimalNumOfHashFunctions(long n, long m) { + // (m / n) * log(2), but avoid truncation due to division! + return Math.max(1, (int) Math.round((double) m / n * Math.log(2))); + } + + /** + * Computes m (total bits of Bloom filter) which is expected to achieve, for the specified + * expected insertions, the required false positive probability. + * + * See http://en.wikipedia.org/wiki/Bloom_filter#Probability_of_false_positives for the formula. + * + * @param n expected insertions (must be positive) + * @param p false positive rate (must be 0 < p < 1) + */ + private static long optimalNumOfBits(long n, double p) { + return (long) (-n * Math.log(p) / (Math.log(2) * Math.log(2))); + } + + static final double DEFAULT_FPP = 0.03; + + /** + * Creates a {@link BloomFilter} with given {@code expectedNumItems} and the default {@code fpp}. 
+ */ + public static BloomFilter create(long expectedNumItems) { + return create(expectedNumItems, DEFAULT_FPP); + } + + /** + * Creates a {@link BloomFilter} with given {@code expectedNumItems} and {@code fpp}, it will pick + * an optimal {@code numBits} and {@code numHashFunctions} for the bloom filter. + */ + public static BloomFilter create(long expectedNumItems, double fpp) { + assert fpp > 0.0 : "False positive probability must be > 0.0"; + assert fpp < 1.0 : "False positive probability must be < 1.0"; + long numBits = optimalNumOfBits(expectedNumItems, fpp); + return create(expectedNumItems, numBits); + } + + /** + * Creates a {@link BloomFilter} with given {@code expectedNumItems} and {@code numBits}, it will + * pick an optimal {@code numHashFunctions} which can minimize {@code fpp} for the bloom filter. + */ + public static BloomFilter create(long expectedNumItems, long numBits) { + assert expectedNumItems > 0 : "Expected insertions must be > 0"; + assert numBits > 0 : "number of bits must be > 0"; + int numHashFunctions = optimalNumOfHashFunctions(expectedNumItems, numBits); + return new BloomFilterImpl(numHashFunctions, numBits); + } +} diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java new file mode 100644 index 000000000000..bbd6cf719dc0 --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.sketch; + +import java.io.UnsupportedEncodingException; + +public class BloomFilterImpl extends BloomFilter { + + private final int numHashFunctions; + private final BitArray bits; + + BloomFilterImpl(int numHashFunctions, long numBits) { + this.numHashFunctions = numHashFunctions; + this.bits = new BitArray(numBits); + } + + @Override + public double expectedFpp() { + return Math.pow((double) bits.cardinality() / bits.bitSize(), numHashFunctions); + } + + @Override + public long bitSize() { + return bits.bitSize(); + } + + private static long hashObjectToLong(Object item) { + if (item instanceof String) { + try { + byte[] bytes = ((String) item).getBytes("utf-8"); + return hashBytesToLong(bytes); + } catch (UnsupportedEncodingException e) { + throw new RuntimeException("Only support utf-8 string", e); + } + } else { + long longValue; + + if (item instanceof Long) { + longValue = (Long) item; + } else if (item instanceof Integer) { + longValue = ((Integer) item).longValue(); + } else if (item instanceof Short) { + longValue = ((Short) item).longValue(); + } else if (item instanceof Byte) { + longValue = ((Byte) item).longValue(); + } else { + throw new IllegalArgumentException( + "Support for " + item.getClass().getName() + " not implemented" + ); + } + + int h1 = Murmur3_x86_32.hashLong(longValue, 0); + int h2 = Murmur3_x86_32.hashLong(longValue, h1); + return (((long) h1) << 32) | (h2 & 0xFFFFFFFFL); + } + } + + private static long hashBytesToLong(byte[] bytes) { + int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0); + int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1); + return (((long) h1) << 32) | (h2 & 0xFFFFFFFFL); + } + + @Override + public boolean put(Object item) { + long bitSize = bits.bitSize(); + + // Here we first hash the input element into 2 int hash values, h1 and h2, then produce n hash + // values by `h1 + i * h2` with 1 <= i <= numHashFunctions. + // Note that `CountMinSketch` use a different strategy for long type, it hash the input long + // element with every i to produce n hash values. 
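+    // For example, with numHashFunctions = 3 the probed positions are (h1 + 1 * h2) % bitSize,
+    // (h1 + 2 * h2) % bitSize and (h1 + 3 * h2) % bitSize (flipping the sign first if the sum
+    // overflows to a negative int), so two 32-bit hashes simulate all n hash functions.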
+ long hash64 = hashObjectToLong(item); + int h1 = (int) (hash64 >> 32); + int h2 = (int) hash64; + + boolean bitsChanged = false; + for (int i = 1; i <= numHashFunctions; i++) { + int combinedHash = h1 + (i * h2); + // Flip all the bits if it's negative (guaranteed positive number) + if (combinedHash < 0) { + combinedHash = ~combinedHash; + } + bitsChanged |= bits.set(combinedHash % bitSize); + } + return bitsChanged; + } + + @Override + public boolean mightContain(Object item) { + long bitSize = bits.bitSize(); + long hash64 = hashObjectToLong(item); + int h1 = (int) (hash64 >> 32); + int h2 = (int) hash64; + + for (int i = 1; i <= numHashFunctions; i++) { + int combinedHash = h1 + (i * h2); + // Flip all the bits if it's negative (guaranteed positive number) + if (combinedHash < 0) { + combinedHash = ~combinedHash; + } + if (!bits.get(combinedHash % bitSize)) { + return false; + } + } + return true; + } + + @Override + public boolean isCompatible(BloomFilter other) { + if (other == null) { + return false; + } + + if (!(other instanceof BloomFilterImpl)) { + return false; + } + + BloomFilterImpl that = (BloomFilterImpl) other; + return this.bitSize() == that.bitSize() && this.numHashFunctions == that.numHashFunctions; + } + + @Override + public BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeException { + // Duplicates the logic of `isCompatible` here to provide better error message. + if (other == null) { + throw new IncompatibleMergeException("Cannot merge null bloom filter"); + } + + if (!(other instanceof BloomFilter)) { + throw new IncompatibleMergeException( + "Cannot merge bloom filter of class " + other.getClass().getName() + ); + } + + BloomFilterImpl that = (BloomFilterImpl) other; + + if (this.bitSize() != that.bitSize()) { + throw new IncompatibleMergeException("Cannot merge bloom filters with different bit size"); + } + + if (this.numHashFunctions != that.numHashFunctions) { + throw new IncompatibleMergeException( + "Cannot merge bloom filters with different number of hash functions"); + } + + this.bits.putAll(that.bits); + return this; + } +} diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BitArraySuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BitArraySuite.scala new file mode 100644 index 000000000000..ff728f0ebcb8 --- /dev/null +++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BitArraySuite.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.sketch + +import scala.util.Random + +import org.scalatest.FunSuite // scalastyle:ignore funsuite + +class BitArraySuite extends FunSuite { // scalastyle:ignore funsuite + + test("error case when create BitArray") { + intercept[IllegalArgumentException](new BitArray(0)) + intercept[IllegalArgumentException](new BitArray(64L * Integer.MAX_VALUE + 1)) + } + + test("bitSize") { + assert(new BitArray(64).bitSize() == 64) + // BitArray is word-aligned, so 65~128 bits need 2 long to store, which is 128 bits. + assert(new BitArray(65).bitSize() == 128) + assert(new BitArray(127).bitSize() == 128) + assert(new BitArray(128).bitSize() == 128) + } + + test("set") { + val bitArray = new BitArray(64) + assert(bitArray.set(1)) + // Only returns true if the bit changed. + assert(!bitArray.set(1)) + assert(bitArray.set(2)) + } + + test("normal operation") { + // use a fixed seed to make the test predictable. + val r = new Random(37) + + val bitArray = new BitArray(320) + val indexes = (1 to 100).map(_ => r.nextInt(320).toLong).distinct + + indexes.foreach(bitArray.set) + indexes.foreach(i => assert(bitArray.get(i))) + assert(bitArray.cardinality() == indexes.length) + } + + test("merge") { + // use a fixed seed to make the test predictable. + val r = new Random(37) + + val bitArray1 = new BitArray(64 * 6) + val bitArray2 = new BitArray(64 * 6) + + val indexes1 = (1 to 100).map(_ => r.nextInt(64 * 6).toLong).distinct + val indexes2 = (1 to 100).map(_ => r.nextInt(64 * 6).toLong).distinct + + indexes1.foreach(bitArray1.set) + indexes2.foreach(bitArray2.set) + + bitArray1.putAll(bitArray2) + indexes1.foreach(i => assert(bitArray1.get(i))) + indexes2.foreach(i => assert(bitArray1.get(i))) + assert(bitArray1.cardinality() == (indexes1 ++ indexes2).distinct.length) + } +} diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala new file mode 100644 index 000000000000..d2de509f1951 --- /dev/null +++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch + +import scala.reflect.ClassTag +import scala.util.Random + +import org.scalatest.FunSuite // scalastyle:ignore funsuite + +class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite + private final val EPSILON = 0.01 + + def testAccuracy[T: ClassTag](typeName: String, numItems: Int)(itemGen: Random => T): Unit = { + test(s"accuracy - $typeName") { + // use a fixed seed to make the test predictable. 
+ val r = new Random(37) + val fpp = 0.05 + val numInsertion = numItems / 10 + + val allItems = Array.fill(numItems)(itemGen(r)) + + val filter = BloomFilter.create(numInsertion, fpp) + + // insert first `numInsertion` items. + allItems.take(numInsertion).foreach(filter.put) + + // false negative is not allowed. + assert(allItems.take(numInsertion).forall(filter.mightContain)) + + // The number of inserted items doesn't exceed `expectedNumItems`, so the `expectedFpp` + // should not be significantly higher than the one we passed in to create this bloom filter. + assert(filter.expectedFpp() - fpp < EPSILON) + + val errorCount = allItems.drop(numInsertion).count(filter.mightContain) + + // Also check the actual fpp is not significantly higher than we expected. + val actualFpp = errorCount.toDouble / (numItems - numInsertion) + assert(actualFpp - fpp < EPSILON) + } + } + + def testMergeInPlace[T: ClassTag](typeName: String, numItems: Int)(itemGen: Random => T): Unit = { + test(s"mergeInPlace - $typeName") { + // use a fixed seed to make the test predictable. + val r = new Random(37) + + val items1 = Array.fill(numItems / 2)(itemGen(r)) + val items2 = Array.fill(numItems / 2)(itemGen(r)) + + val filter1 = BloomFilter.create(numItems) + items1.foreach(filter1.put) + + val filter2 = BloomFilter.create(numItems) + items2.foreach(filter2.put) + + filter1.mergeInPlace(filter2) + + // After merge, `filter1` has `numItems` items which doesn't exceed `expectedNumItems`, so the + // `expectedFpp` should not be significantly higher than the default one. + assert(filter1.expectedFpp() - BloomFilter.DEFAULT_FPP < EPSILON) + + items1.foreach(i => assert(filter1.mightContain(i))) + items2.foreach(i => assert(filter1.mightContain(i))) + } + } + + def testItemType[T: ClassTag](typeName: String, numItems: Int)(itemGen: Random => T): Unit = { + testAccuracy[T](typeName, numItems)(itemGen) + testMergeInPlace[T](typeName, numItems)(itemGen) + } + + testItemType[Byte]("Byte", 160) { _.nextInt().toByte } + + testItemType[Short]("Short", 1000) { _.nextInt().toShort } + + testItemType[Int]("Int", 100000) { _.nextInt() } + + testItemType[Long]("Long", 100000) { _.nextLong() } + + testItemType[String]("String", 100000) { r => r.nextString(r.nextInt(512)) } + + test("incompatible merge") { + intercept[IncompatibleMergeException] { + BloomFilter.create(1000).mergeInPlace(null) + } + + intercept[IncompatibleMergeException] { + val filter1 = BloomFilter.create(1000, 6400) + val filter2 = BloomFilter.create(1000, 3200) + filter1.mergeInPlace(filter2) + } + + intercept[IncompatibleMergeException] { + val filter1 = BloomFilter.create(1000, 6400) + val filter2 = BloomFilter.create(2000, 6400) + filter1.mergeInPlace(filter2) + } + } +} From fdcc3512f7b45e5b067fc26cb05146f79c4a5177 Mon Sep 17 00:00:00 2001 From: tedyu Date: Mon, 25 Jan 2016 18:23:47 -0800 Subject: [PATCH 011/131] [SPARK-12934] use try-with-resources for streams liancheng please take a look Author: tedyu Closes #10906 from tedyu/master. 
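A minimal caller-side sketch of this contract, kept separate from the patch itself: only `writeTo` and `readFrom` are documented here, while `CountMinSketch.create(eps, confidence, seed)`, `add` and `estimateCount` are assumed from the existing sketch API. Try-with-resources closes the streams that the sketch deliberately leaves open.

```
import java.io.*;

import org.apache.spark.util.sketch.CountMinSketch;

public class CountMinSketchStreams {
  public static void main(String[] args) throws IOException {
    // Assumed factory: relative error 0.1%, confidence 99%, seed 42.
    CountMinSketch sketch = CountMinSketch.create(0.001, 0.99, 42);
    for (int i = 0; i < 10000; i++) {
      sketch.add(i % 100);
    }

    File file = File.createTempFile("cms", ".bin");

    // writeTo does not close the stream, so the caller closes it via try-with-resources.
    try (OutputStream out = new BufferedOutputStream(new FileOutputStream(file))) {
      sketch.writeTo(out);
    }

    // readFrom leaves the stream open as well.
    CountMinSketch restored;
    try (InputStream in = new BufferedInputStream(new FileInputStream(file))) {
      restored = CountMinSketch.readFrom(in);
    }

    System.out.println(restored.estimateCount(7));  // about 100; a count-min sketch can only overestimate
  }
}
```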
--- .../main/java/org/apache/spark/util/sketch/CountMinSketch.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index 67938644d9f6..9f4ff42403c3 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -128,11 +128,13 @@ public abstract CountMinSketch mergeInPlace(CountMinSketch other) /** * Writes out this {@link CountMinSketch} to an output stream in binary format. + * It is the caller's responsibility to close the stream */ public abstract void writeTo(OutputStream out) throws IOException; /** * Reads in a {@link CountMinSketch} from an input stream. + * It is the caller's responsibility to close the stream */ public static CountMinSketch readFrom(InputStream in) throws IOException { return CountMinSketchImpl.readFrom(in); From b66afdeb5253913d916dcf159aaed4ffdc15fd4b Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 25 Jan 2016 22:38:31 -0800 Subject: [PATCH 012/131] [SPARK-11922][PYSPARK][ML] Python api for ml.feature.quantile discretizer Add Python API for ml.feature.QuantileDiscretizer. One open question: Do we want to do this stuff to re-use the java model, create a new model, or use a different wrapper around the java model. cc brkyvz & mengxr Author: Holden Karau Closes #10085 from holdenk/SPARK-11937-SPARK-11922-Python-API-for-ml.feature.QuantileDiscretizer. --- python/pyspark/ml/feature.py | 89 ++++++++++++++++++++++++++++++++++-- 1 file changed, 85 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 1fa0eab384e7..f139d81bc490 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -30,10 +30,10 @@ __all__ = ['Binarizer', 'Bucketizer', 'CountVectorizer', 'CountVectorizerModel', 'DCT', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', 'IndexToString', 'MinMaxScaler', 'MinMaxScalerModel', 'NGram', 'Normalizer', 'OneHotEncoder', 'PCA', 'PCAModel', - 'PolynomialExpansion', 'RegexTokenizer', 'RFormula', 'RFormulaModel', 'SQLTransformer', - 'StandardScaler', 'StandardScalerModel', 'StopWordsRemover', 'StringIndexer', - 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', - 'Word2Vec', 'Word2VecModel'] + 'PolynomialExpansion', 'QuantileDiscretizer', 'RegexTokenizer', 'RFormula', + 'RFormulaModel', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', + 'StopWordsRemover', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', + 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel'] @inherit_doc @@ -991,6 +991,87 @@ def getDegree(self): return self.getOrDefault(self.degree) +@inherit_doc +class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned + categorical features. The bin ranges are chosen by taking a sample of the data and dividing it + into roughly equal parts. The lower and upper bin bounds will be -Infinity and +Infinity, + covering all real values. This attempts to find numBuckets partitions based on a sample of data, + but it may find fewer depending on the data sample values. 
+ + >>> df = sqlContext.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"]) + >>> qds = QuantileDiscretizer(numBuckets=2, + ... inputCol="values", outputCol="buckets") + >>> bucketizer = qds.fit(df) + >>> splits = bucketizer.getSplits() + >>> splits[0] + -inf + >>> print("%2.1f" % round(splits[1], 1)) + 0.4 + >>> bucketed = bucketizer.transform(df).head() + >>> bucketed.buckets + 0.0 + + .. versionadded:: 2.0.0 + """ + + # a placeholder to make it appear in the generated doc + numBuckets = Param(Params._dummy(), "numBuckets", + "Maximum number of buckets (quantiles, or " + + "categories) into which data points are grouped. Must be >= 2. Default 2.") + + @keyword_only + def __init__(self, numBuckets=2, inputCol=None, outputCol=None): + """ + __init__(self, numBuckets=2, inputCol=None, outputCol=None) + """ + super(QuantileDiscretizer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.QuantileDiscretizer", + self.uid) + self.numBuckets = Param(self, "numBuckets", + "Maximum number of buckets (quantiles, or " + + "categories) into which data points are grouped. Must be >= 2.") + self._setDefault(numBuckets=2) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.0.0") + def setParams(self, numBuckets=2, inputCol=None, outputCol=None): + """ + setParams(self, numBuckets=2, inputCol=None, outputCol=None) + Set the params for the QuantileDiscretizer + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + @since("2.0.0") + def setNumBuckets(self, value): + """ + Sets the value of :py:attr:`numBuckets`. + """ + self._paramMap[self.numBuckets] = value + return self + + @since("2.0.0") + def getNumBuckets(self): + """ + Gets the value of numBuckets or its default value. + """ + return self.getOrDefault(self.numBuckets) + + def _create_model(self, java_model): + """ + Private method to convert the java_model to a Python model. + """ + return Bucketizer(splits=list(java_model.getSplits()), + inputCol=self.getInputCol(), + outputCol=self.getOutputCol()) + + @inherit_doc @ignore_unicode_prefix class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol): From ae47ba718a280fc12720a71b981c38dbe647f35b Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Mon, 25 Jan 2016 22:41:52 -0800 Subject: [PATCH 013/131] [SPARK-12834] Change ser/de of JavaArray and JavaList https://issues.apache.org/jira/browse/SPARK-12834 We use `SerDe.dumps()` to serialize `JavaArray` and `JavaList` in `PythonMLLibAPI`, then deserialize them with `PickleSerializer` in Python side. However, there is no need to transform them in such an inefficient way. Instead of it, we can use type conversion to convert them, e.g. `list(JavaArray)` or `list(JavaList)`. What's more, there is an issue to Ser/De Scala Array as I said in https://issues.apache.org/jira/browse/SPARK-12780 Author: Xusen Yin Closes #10772 from yinxusen/SPARK-12834. 
--- .../org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 05f9a76d3267..088ec6a0c046 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -1490,7 +1490,11 @@ private[spark] object SerDe extends Serializable { initialize() def dumps(obj: AnyRef): Array[Byte] = { - new Pickler().dumps(obj) + obj match { + // Pickler in Python side cannot deserialize Scala Array normally. See SPARK-12834. + case array: Array[_] => new Pickler().dumps(array.toSeq.asJava) + case _ => new Pickler().dumps(obj) + } } def loads(bytes: Array[Byte]): AnyRef = { From 27c910f7f29087d1ac216d4933d641d6515fd6ad Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 25 Jan 2016 22:53:34 -0800 Subject: [PATCH 014/131] [SPARK-10086][MLLIB][STREAMING][PYSPARK] ignore StreamingKMeans test in PySpark for now I saw several failures from recent PR builds, e.g., https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/50015/consoleFull. This PR marks the test as ignored and we will fix the flakyness in SPARK-10086. gliptak Do you know why the test failure didn't show up in the Jenkins "Test Result"? cc: jkbradley Author: Xiangrui Meng Closes #10909 from mengxr/SPARK-10086. --- python/pyspark/mllib/tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 79ce4959c926..25a7c29982b3 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -1189,6 +1189,7 @@ def condition(): self._eventually(condition, catch_assertions=True) + @unittest.skip("SPARK-10086: Flaky StreamingKMeans test in PySpark") def test_trainOn_predictOn(self): """Test that prediction happens on the updated model.""" stkm = StreamingKMeans(decayFactor=0.0, k=2) From d54cfed5a6953a9ce2b9de2f31ee2d673cb5cc62 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 26 Jan 2016 00:51:08 -0800 Subject: [PATCH 015/131] [SQL][MINOR] A few minor tweaks to CSV reader. This pull request simply fixes a few minor coding style issues in csv, as I was reviewing the change post-hoc. Author: Reynold Xin Closes #10919 from rxin/csv-minor. --- .../datasources/csv/CSVInferSchema.scala | 21 +++++++------------ .../datasources/csv/CSVRelation.scala | 2 +- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index 0aa4539e6051..ace8cd7ad864 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -30,16 +30,15 @@ import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.types._ -private[sql] object CSVInferSchema { +private[csv] object CSVInferSchema { /** * Similar to the JSON schema inference * 1. Infer type of each row * 2. Merge row types to find common type * 3. Replace any null types with string type - * TODO(hossein): Can we reuse JSON schema inference? 
[SPARK-12670] */ - def apply( + def infer( tokenRdd: RDD[Array[String]], header: Array[String], nullValue: String = ""): StructType = { @@ -65,10 +64,7 @@ private[sql] object CSVInferSchema { rowSoFar } - private[csv] def mergeRowTypes( - first: Array[DataType], - second: Array[DataType]): Array[DataType] = { - + def mergeRowTypes(first: Array[DataType], second: Array[DataType]): Array[DataType] = { first.zipAll(second, NullType, NullType).map { case ((a, b)) => val tpe = findTightestCommonType(a, b).getOrElse(StringType) tpe match { @@ -82,8 +78,7 @@ private[sql] object CSVInferSchema { * Infer type of string field. Given known type Double, and a string "1", there is no * point checking if it is an Int, as the final type must be Double or higher. */ - private[csv] def inferField( - typeSoFar: DataType, field: String, nullValue: String = ""): DataType = { + def inferField(typeSoFar: DataType, field: String, nullValue: String = ""): DataType = { if (field == null || field.isEmpty || field == nullValue) { typeSoFar } else { @@ -155,7 +150,8 @@ private[sql] object CSVInferSchema { } } -object CSVTypeCast { + +private[csv] object CSVTypeCast { /** * Casts given string datum to specified type. @@ -167,7 +163,7 @@ object CSVTypeCast { * @param datum string value * @param castType SparkSQL type */ - private[csv] def castTo( + def castTo( datum: String, castType: DataType, nullable: Boolean = true, @@ -201,10 +197,9 @@ object CSVTypeCast { * Helper method that converts string representation of a character to actual character. * It handles some Java escaped strings and throws exception if given string is longer than one * character. - * */ @throws[IllegalArgumentException] - private[csv] def toChar(str: String): Char = { + def toChar(str: String): Char = { if (str.charAt(0) == '\\') { str.charAt(1) match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index 5959f7cc5051..dc449fea956f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -139,7 +139,7 @@ private[csv] class CSVRelation( val parsedRdd = tokenRdd(header, paths) if (params.inferSchemaFlag) { - CSVInferSchema(parsedRdd, header, params.nullValue) + CSVInferSchema.infer(parsedRdd, header, params.nullValue) } else { // By default fields are assumed to be StringType val schemaFields = header.map { fieldName => From 6743de3a98e3f0d0e6064ca1872fa88c3aeaa143 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 26 Jan 2016 00:53:05 -0800 Subject: [PATCH 016/131] [SPARK-12937][SQL] bloom filter serialization This PR adds serialization support for BloomFilter. A version number is added to version the serialized binary format. Author: Wenchen Fan Closes #10920 from cloud-fan/bloom-filter. 
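A minimal round-trip sketch of the new serialization methods, composed with the `create`/`put`/`mightContain` API added by the earlier bloom filter patch; the byte-array streams are plain `java.io` and stand in for whatever stream a caller would actually use.

```
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;

import org.apache.spark.util.sketch.BloomFilter;

public class BloomFilterSerDe {
  public static void main(String[] args) throws IOException {
    BloomFilter filter = BloomFilter.create(1000, 0.03);
    for (long i = 0; i < 1000; i++) {
      filter.put(i);  // Long is one of the supported item types
    }

    // writeTo leaves the stream open; closing a ByteArrayOutputStream is a no-op anyway.
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    filter.writeTo(out);

    // readFrom parses the version-1 binary layout described in the diff below.
    BloomFilter restored = BloomFilter.readFrom(new ByteArrayInputStream(out.toByteArray()));

    // The deserialized filter behaves like the original one.
    assert restored.isCompatible(filter);
    assert restored.mightContain(42L);  // 42 was definitely inserted, so no false negative
  }
}
```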
--- .../apache/spark/util/sketch/BitArray.java | 46 +++++++++++++----- .../apache/spark/util/sketch/BloomFilter.java | 42 +++++++++++++++- .../spark/util/sketch/BloomFilterImpl.java | 48 ++++++++++++++++++- .../spark/util/sketch/CountMinSketch.java | 25 ++++++---- .../spark/util/sketch/CountMinSketchImpl.java | 22 +-------- .../spark/util/sketch/BloomFilterSuite.scala | 20 ++++++++ 6 files changed, 159 insertions(+), 44 deletions(-) diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java index 1bc665ad54b7..2a0484e324b1 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java @@ -17,6 +17,9 @@ package org.apache.spark.util.sketch; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; import java.util.Arrays; public final class BitArray { @@ -24,6 +27,9 @@ public final class BitArray { private long bitCount; static int numWords(long numBits) { + if (numBits <= 0) { + throw new IllegalArgumentException("numBits must be positive, but got " + numBits); + } long numWords = (long) Math.ceil(numBits / 64.0); if (numWords > Integer.MAX_VALUE) { throw new IllegalArgumentException("Can't allocate enough space for " + numBits + " bits"); @@ -32,13 +38,14 @@ static int numWords(long numBits) { } BitArray(long numBits) { - if (numBits <= 0) { - throw new IllegalArgumentException("numBits must be positive"); - } - this.data = new long[numWords(numBits)]; + this(new long[numWords(numBits)]); + } + + private BitArray(long[] data) { + this.data = data; long bitCount = 0; - for (long value : data) { - bitCount += Long.bitCount(value); + for (long word : data) { + bitCount += Long.bitCount(word); } this.bitCount = bitCount; } @@ -78,13 +85,28 @@ void putAll(BitArray array) { this.bitCount = bitCount; } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || !(o instanceof BitArray)) return false; + void writeTo(DataOutputStream out) throws IOException { + out.writeInt(data.length); + for (long datum : data) { + out.writeLong(datum); + } + } + + static BitArray readFrom(DataInputStream in) throws IOException { + int numWords = in.readInt(); + long[] data = new long[numWords]; + for (int i = 0; i < numWords; i++) { + data[i] = in.readLong(); + } + return new BitArray(data); + } - BitArray bitArray = (BitArray) o; - return Arrays.equals(data, bitArray.data); + @Override + public boolean equals(Object other) { + if (this == other) return true; + if (other == null || !(other instanceof BitArray)) return false; + BitArray that = (BitArray) other; + return Arrays.equals(data, that.data); } @Override diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index 38949c6311df..00378d58518f 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -17,6 +17,10 @@ package org.apache.spark.util.sketch; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + /** * A Bloom filter is a space-efficient probabilistic data structure, that is used to test whether * an element is a member of a set. 
It returns false when the element is definitely not in the @@ -39,6 +43,28 @@ * The implementation is largely based on the {@code BloomFilter} class from guava. */ public abstract class BloomFilter { + + public enum Version { + /** + * {@code BloomFilter} binary format version 1 (all values written in big-endian order): + * - Version number, always 1 (32 bit) + * - Total number of words of the underlying bit array (32 bit) + * - The words/longs (numWords * 64 bit) + * - Number of hash functions (32 bit) + */ + V1(1); + + private final int versionNumber; + + Version(int versionNumber) { + this.versionNumber = versionNumber; + } + + int getVersionNumber() { + return versionNumber; + } + } + /** * Returns the false positive probability, i.e. the probability that * {@linkplain #mightContain(Object)} will erroneously return {@code true} for an object that @@ -83,7 +109,7 @@ public abstract class BloomFilter { * bloom filters are appropriately sized to avoid saturating them. * * @param other The bloom filter to combine this bloom filter with. It is not mutated. - * @throws IllegalArgumentException if {@code isCompatible(that) == false} + * @throws IncompatibleMergeException if {@code isCompatible(other) == false} */ public abstract BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeException; @@ -93,6 +119,20 @@ public abstract class BloomFilter { */ public abstract boolean mightContain(Object item); + /** + * Writes out this {@link BloomFilter} to an output stream in binary format. + * It is the caller's responsibility to close the stream. + */ + public abstract void writeTo(OutputStream out) throws IOException; + + /** + * Reads in a {@link BloomFilter} from an input stream. + * It is the caller's responsibility to close the stream. + */ + public static BloomFilter readFrom(InputStream in) throws IOException { + return BloomFilterImpl.readFrom(in); + } + /** * Computes the optimal k (number of hashes per element inserted in Bloom filter), given the * expected insertions and total number of bits in the Bloom filter. 
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java index bbd6cf719dc0..1c08d07afaea 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java @@ -17,7 +17,7 @@ package org.apache.spark.util.sketch; -import java.io.UnsupportedEncodingException; +import java.io.*; public class BloomFilterImpl extends BloomFilter { @@ -25,8 +25,32 @@ public class BloomFilterImpl extends BloomFilter { private final BitArray bits; BloomFilterImpl(int numHashFunctions, long numBits) { + this(new BitArray(numBits), numHashFunctions); + } + + private BloomFilterImpl(BitArray bits, int numHashFunctions) { + this.bits = bits; this.numHashFunctions = numHashFunctions; - this.bits = new BitArray(numBits); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } + + if (other == null || !(other instanceof BloomFilterImpl)) { + return false; + } + + BloomFilterImpl that = (BloomFilterImpl) other; + + return this.numHashFunctions == that.numHashFunctions && this.bits.equals(that.bits); + } + + @Override + public int hashCode() { + return bits.hashCode() * 31 + numHashFunctions; } @Override @@ -161,4 +185,24 @@ public BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeExcep this.bits.putAll(that.bits); return this; } + + @Override + public void writeTo(OutputStream out) throws IOException { + DataOutputStream dos = new DataOutputStream(out); + + dos.writeInt(Version.V1.getVersionNumber()); + bits.writeTo(dos); + dos.writeInt(numHashFunctions); + } + + public static BloomFilterImpl readFrom(InputStream in) throws IOException { + DataInputStream dis = new DataInputStream(in); + + int version = dis.readInt(); + if (version != Version.V1.getVersionNumber()) { + throw new IOException("Unexpected Bloom filter version number (" + version + ")"); + } + + return new BloomFilterImpl(BitArray.readFrom(dis), dis.readInt()); + } } diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index 9f4ff42403c3..00c0b1b9e2db 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -55,10 +55,21 @@ * This implementation is largely based on the {@code CountMinSketch} class from stream-lib. */ abstract public class CountMinSketch { - /** - * Version number of the serialized binary format. - */ + public enum Version { + /** + * {@code CountMinSketch} binary format version 1 (all values written in big-endian order): + * - Version number, always 1 (32 bit) + * - Total count of added items (64 bit) + * - Depth (32 bit) + * - Width (32 bit) + * - Hash functions (depth * 64 bit) + * - Count table + * - Row 0 (width * 64 bit) + * - Row 1 (width * 64 bit) + * - ... + * - Row depth - 1 (width * 64 bit) + */ V1(1); private final int versionNumber; @@ -67,13 +78,11 @@ public enum Version { this.versionNumber = versionNumber; } - public int getVersionNumber() { + int getVersionNumber() { return versionNumber; } } - public abstract Version version(); - /** * Returns the relative error (or {@code eps}) of this {@link CountMinSketch}. 
*/ @@ -128,13 +137,13 @@ public abstract CountMinSketch mergeInPlace(CountMinSketch other) /** * Writes out this {@link CountMinSketch} to an output stream in binary format. - * It is the caller's responsibility to close the stream + * It is the caller's responsibility to close the stream. */ public abstract void writeTo(OutputStream out) throws IOException; /** * Reads in a {@link CountMinSketch} from an input stream. - * It is the caller's responsibility to close the stream + * It is the caller's responsibility to close the stream. */ public static CountMinSketch readFrom(InputStream in) throws IOException { return CountMinSketchImpl.readFrom(in); diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index 0209446ea3b1..d08809605a93 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -26,21 +26,6 @@ import java.util.Arrays; import java.util.Random; -/* - * Binary format of a serialized CountMinSketchImpl, version 1 (all values written in big-endian - * order): - * - * - Version number, always 1 (32 bit) - * - Total count of added items (64 bit) - * - Depth (32 bit) - * - Width (32 bit) - * - Hash functions (depth * 64 bit) - * - Count table - * - Row 0 (width * 64 bit) - * - Row 1 (width * 64 bit) - * - ... - * - Row depth - 1 (width * 64 bit) - */ class CountMinSketchImpl extends CountMinSketch { public static final long PRIME_MODULUS = (1L << 31) - 1; @@ -112,11 +97,6 @@ public int hashCode() { return hash; } - @Override - public Version version() { - return Version.V1; - } - private void initTablesWith(int depth, int width, int seed) { this.table = new long[depth][width]; this.hashA = new long[depth]; @@ -327,7 +307,7 @@ public CountMinSketch mergeInPlace(CountMinSketch other) throws IncompatibleMerg public void writeTo(OutputStream out) throws IOException { DataOutputStream dos = new DataOutputStream(out); - dos.writeInt(version().getVersionNumber()); + dos.writeInt(Version.V1.getVersionNumber()); dos.writeLong(this.totalCount); dos.writeInt(this.depth); diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala index d2de509f1951..a0408d2da4df 100644 --- a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala +++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.util.sketch +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} + import scala.reflect.ClassTag import scala.util.Random @@ -25,6 +27,20 @@ import org.scalatest.FunSuite // scalastyle:ignore funsuite class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite private final val EPSILON = 0.01 + // Serializes and deserializes a given `BloomFilter`, then checks whether the deserialized + // version is equivalent to the original one. 
+ private def checkSerDe(filter: BloomFilter): Unit = { + val out = new ByteArrayOutputStream() + filter.writeTo(out) + out.close() + + val in = new ByteArrayInputStream(out.toByteArray) + val deserialized = BloomFilter.readFrom(in) + in.close() + + assert(filter == deserialized) + } + def testAccuracy[T: ClassTag](typeName: String, numItems: Int)(itemGen: Random => T): Unit = { test(s"accuracy - $typeName") { // use a fixed seed to make the test predictable. @@ -51,6 +67,8 @@ class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite // Also check the actual fpp is not significantly higher than we expected. val actualFpp = errorCount.toDouble / (numItems - numInsertion) assert(actualFpp - fpp < EPSILON) + + checkSerDe(filter) } } @@ -76,6 +94,8 @@ class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite items1.foreach(i => assert(filter1.mightContain(i))) items2.foreach(i => assert(filter1.mightContain(i))) + + checkSerDe(filter1) } } From 5936bf9fa85ccf7f0216145356140161c2801682 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 26 Jan 2016 11:36:00 +0000 Subject: [PATCH 017/131] [SPARK-12961][CORE] Prevent snappy-java memory leak JIRA: https://issues.apache.org/jira/browse/SPARK-12961 To prevent memory leak in snappy-java, just call the method once and cache the result. After the library releases new version, we can remove this object. JoshRosen Author: Liang-Chi Hsieh Closes #10875 from viirya/prevent-snappy-memory-leak. --- .../apache/spark/io/CompressionCodec.scala | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 717804626f85..ae014becef75 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -149,12 +149,7 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { */ @DeveloperApi class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { - - try { - Snappy.getNativeLibraryVersion - } catch { - case e: Error => throw new IllegalArgumentException(e) - } + val version = SnappyCompressionCodec.version override def compressedOutputStream(s: OutputStream): OutputStream = { val blockSize = conf.getSizeAsBytes("spark.io.compression.snappy.blockSize", "32k").toInt @@ -164,6 +159,19 @@ class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { override def compressedInputStream(s: InputStream): InputStream = new SnappyInputStream(s) } +/** + * Object guards against memory leak bug in snappy-java library: + * (https://github.com/xerial/snappy-java/issues/131). + * Before a new version of the library, we only call the method once and cache the result. + */ +private final object SnappyCompressionCodec { + private lazy val version: String = try { + Snappy.getNativeLibraryVersion + } catch { + case e: Error => throw new IllegalArgumentException(e) + } +} + /** * Wrapper over [[SnappyOutputStream]] which guards against write-after-close and double-close * issues. See SPARK-7660 for more details. 
This wrapping can be removed if we upgrade to a version From 649e9d0f5b2d5fc13f2dd5be675331510525927f Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 26 Jan 2016 11:55:28 +0000 Subject: [PATCH 018/131] [SPARK-3369][CORE][STREAMING] Java mapPartitions Iterator->Iterable is inconsistent with Scala's Iterator->Iterator Fix Java function API methods for flatMap and mapPartitions to require producing only an Iterator, not Iterable. Also fix DStream.flatMap to require a function producing TraversableOnce only, not Traversable. CC rxin pwendell for API change; tdas since it also touches streaming. Author: Sean Owen Closes #10413 from srowen/SPARK-3369. --- .../api/java/function/CoGroupFunction.java | 2 +- .../java/function/DoubleFlatMapFunction.java | 3 +- .../api/java/function/FlatMapFunction.java | 3 +- .../api/java/function/FlatMapFunction2.java | 3 +- .../java/function/FlatMapGroupsFunction.java | 2 +- .../java/function/MapPartitionsFunction.java | 2 +- .../java/function/PairFlatMapFunction.java | 3 +- .../apache/spark/api/java/JavaRDDLike.scala | 20 ++++++------ .../java/org/apache/spark/JavaAPISuite.java | 24 +++++++------- docs/streaming-programming-guide.md | 4 +-- .../apache/spark/examples/JavaPageRank.java | 18 +++++------ .../apache/spark/examples/JavaWordCount.java | 5 +-- .../streaming/JavaActorWordCount.java | 5 +-- .../streaming/JavaCustomReceiver.java | 7 +++-- .../streaming/JavaDirectKafkaWordCount.java | 6 ++-- .../streaming/JavaKafkaWordCount.java | 9 +++--- .../streaming/JavaNetworkWordCount.java | 11 ++++--- .../JavaRecoverableNetworkWordCount.java | 6 ++-- .../streaming/JavaSqlNetworkWordCount.java | 8 ++--- .../JavaStatefulNetworkWordCount.java | 5 +-- .../JavaTwitterHashTagJoinSentiments.java | 5 +-- .../java/org/apache/spark/Java8APISuite.java | 2 +- .../streaming/JavaKinesisWordCountASL.java | 9 ++++-- project/MimaExcludes.scala | 31 +++++++++++++++++++ .../scala/org/apache/spark/sql/Dataset.scala | 4 +-- .../apache/spark/sql/JavaDatasetSuite.java | 25 +++++++-------- .../streaming/api/java/JavaDStreamLike.scala | 10 +++--- .../spark/streaming/dstream/DStream.scala | 2 +- .../streaming/dstream/FlatMappedDStream.scala | 2 +- .../apache/spark/streaming/JavaAPISuite.java | 20 ++++++------ 30 files changed, 146 insertions(+), 110 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java index 279639af5d43..07aebb75e8f4 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java @@ -25,5 +25,5 @@ * Datasets. */ public interface CoGroupFunction extends Serializable { - Iterable call(K key, Iterator left, Iterator right) throws Exception; + Iterator call(K key, Iterator left, Iterator right) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java index 57fd0a7a8049..576087b6f428 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java @@ -18,10 +18,11 @@ package org.apache.spark.api.java.function; import java.io.Serializable; +import java.util.Iterator; /** * A function that returns zero or more records of type Double from each input record. 
*/ public interface DoubleFlatMapFunction extends Serializable { - public Iterable call(T t) throws Exception; + Iterator call(T t) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java index ef0d1824121e..2d8ea6d1a5a7 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java @@ -18,10 +18,11 @@ package org.apache.spark.api.java.function; import java.io.Serializable; +import java.util.Iterator; /** * A function that returns zero or more output records from each input record. */ public interface FlatMapFunction extends Serializable { - Iterable call(T t) throws Exception; + Iterator call(T t) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java index 14a98a38ef5a..fc97b63f825d 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java @@ -18,10 +18,11 @@ package org.apache.spark.api.java.function; import java.io.Serializable; +import java.util.Iterator; /** * A function that takes two inputs and returns zero or more output records. */ public interface FlatMapFunction2 extends Serializable { - Iterable call(T1 t1, T2 t2) throws Exception; + Iterator call(T1 t1, T2 t2) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java index d7a80e7b129b..bae574ab5755 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java @@ -24,5 +24,5 @@ * A function that returns zero or more output records from each grouping key and its values. */ public interface FlatMapGroupsFunction extends Serializable { - Iterable call(K key, Iterator values) throws Exception; + Iterator call(K key, Iterator values) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java index 6cb569ce0cb6..cf9945a215af 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java @@ -24,5 +24,5 @@ * Base interface for function used in Dataset's mapPartitions. */ public interface MapPartitionsFunction extends Serializable { - Iterable call(Iterator input) throws Exception; + Iterator call(Iterator input) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java index 691ef2eceb1f..51eed2e67b9f 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java @@ -18,6 +18,7 @@ package org.apache.spark.api.java.function; import java.io.Serializable; +import java.util.Iterator; import scala.Tuple2; @@ -26,5 +27,5 @@ * key-value pairs are represented as scala.Tuple2 objects. 
*/ public interface PairFlatMapFunction extends Serializable { - public Iterable> call(T t) throws Exception; + Iterator> call(T t) throws Exception; } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 0f8d13cf5cc2..7340defabfe5 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -121,7 +121,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * RDD, and then flattening the results. */ def flatMap[U](f: FlatMapFunction[T, U]): JavaRDD[U] = { - def fn: (T) => Iterable[U] = (x: T) => f.call(x).asScala + def fn: (T) => Iterator[U] = (x: T) => f.call(x).asScala JavaRDD.fromRDD(rdd.flatMap(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -130,7 +130,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * RDD, and then flattening the results. */ def flatMapToDouble(f: DoubleFlatMapFunction[T]): JavaDoubleRDD = { - def fn: (T) => Iterable[jl.Double] = (x: T) => f.call(x).asScala + def fn: (T) => Iterator[jl.Double] = (x: T) => f.call(x).asScala new JavaDoubleRDD(rdd.flatMap(fn).map((x: jl.Double) => x.doubleValue())) } @@ -139,7 +139,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * RDD, and then flattening the results. */ def flatMapToPair[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { - def fn: (T) => Iterable[(K2, V2)] = (x: T) => f.call(x).asScala + def fn: (T) => Iterator[(K2, V2)] = (x: T) => f.call(x).asScala def cm: ClassTag[(K2, V2)] = implicitly[ClassTag[(K2, V2)]] JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -149,7 +149,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = { def fn: (Iterator[T]) => Iterator[U] = { - (x: Iterator[T]) => f.call(x.asJava).iterator().asScala + (x: Iterator[T]) => f.call(x.asJava).asScala } JavaRDD.fromRDD(rdd.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -160,7 +160,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U], preservesPartitioning: Boolean): JavaRDD[U] = { def fn: (Iterator[T]) => Iterator[U] = { - (x: Iterator[T]) => f.call(x.asJava).iterator().asScala + (x: Iterator[T]) => f.call(x.asJava).asScala } JavaRDD.fromRDD( rdd.mapPartitions(fn, preservesPartitioning)(fakeClassTag[U]))(fakeClassTag[U]) @@ -171,7 +171,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = { def fn: (Iterator[T]) => Iterator[jl.Double] = { - (x: Iterator[T]) => f.call(x.asJava).iterator().asScala + (x: Iterator[T]) => f.call(x.asJava).asScala } new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: jl.Double) => x.doubleValue())) } @@ -182,7 +182,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]): JavaPairRDD[K2, V2] = { def fn: (Iterator[T]) => Iterator[(K2, V2)] = { - (x: Iterator[T]) => f.call(x.asJava).iterator().asScala + (x: Iterator[T]) => f.call(x.asJava).asScala } JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -193,7 +193,7 @@ trait JavaRDDLike[T, This <: 
JavaRDDLike[T, This]] extends Serializable { def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]], preservesPartitioning: Boolean): JavaDoubleRDD = { def fn: (Iterator[T]) => Iterator[jl.Double] = { - (x: Iterator[T]) => f.call(x.asJava).iterator().asScala + (x: Iterator[T]) => f.call(x.asJava).asScala } new JavaDoubleRDD(rdd.mapPartitions(fn, preservesPartitioning) .map(x => x.doubleValue())) @@ -205,7 +205,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2], preservesPartitioning: Boolean): JavaPairRDD[K2, V2] = { def fn: (Iterator[T]) => Iterator[(K2, V2)] = { - (x: Iterator[T]) => f.call(x.asJava).iterator().asScala + (x: Iterator[T]) => f.call(x.asJava).asScala } JavaPairRDD.fromRDD( rdd.mapPartitions(fn, preservesPartitioning))(fakeClassTag[K2], fakeClassTag[V2]) @@ -290,7 +290,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { other: JavaRDDLike[U, _], f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V]): JavaRDD[V] = { def fn: (Iterator[T], Iterator[U]) => Iterator[V] = { - (x: Iterator[T], y: Iterator[U]) => f.call(x.asJava, y.asJava).iterator().asScala + (x: Iterator[T], y: Iterator[U]) => f.call(x.asJava, y.asJava).asScala } JavaRDD.fromRDD( rdd.zipPartitions(other.rdd)(fn)(other.classTag, fakeClassTag[V]))(fakeClassTag[V]) diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 44d5cac7c2de..8117ad9e6064 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -880,8 +880,8 @@ public void flatMap() { "The quick brown fox jumps over the lazy dog.")); JavaRDD words = rdd.flatMap(new FlatMapFunction() { @Override - public Iterable call(String x) { - return Arrays.asList(x.split(" ")); + public Iterator call(String x) { + return Arrays.asList(x.split(" ")).iterator(); } }); Assert.assertEquals("Hello", words.first()); @@ -890,12 +890,12 @@ public Iterable call(String x) { JavaPairRDD pairsRDD = rdd.flatMapToPair( new PairFlatMapFunction() { @Override - public Iterable> call(String s) { + public Iterator> call(String s) { List> pairs = new LinkedList<>(); for (String word : s.split(" ")) { pairs.add(new Tuple2<>(word, word)); } - return pairs; + return pairs.iterator(); } } ); @@ -904,12 +904,12 @@ public Iterable> call(String s) { JavaDoubleRDD doubles = rdd.flatMapToDouble(new DoubleFlatMapFunction() { @Override - public Iterable call(String s) { + public Iterator call(String s) { List lengths = new LinkedList<>(); for (String word : s.split(" ")) { lengths.add((double) word.length()); } - return lengths; + return lengths.iterator(); } }); Assert.assertEquals(5.0, doubles.first(), 0.01); @@ -930,8 +930,8 @@ public void mapsFromPairsToPairs() { JavaPairRDD swapped = pairRDD.flatMapToPair( new PairFlatMapFunction, String, Integer>() { @Override - public Iterable> call(Tuple2 item) { - return Collections.singletonList(item.swap()); + public Iterator> call(Tuple2 item) { + return Collections.singletonList(item.swap()).iterator(); } }); swapped.collect(); @@ -951,12 +951,12 @@ public void mapPartitions() { JavaRDD partitionSums = rdd.mapPartitions( new FlatMapFunction, Integer>() { @Override - public Iterable call(Iterator iter) { + public Iterator call(Iterator iter) { int sum = 0; while (iter.hasNext()) { sum += iter.next(); } - return 
Collections.singletonList(sum); + return Collections.singletonList(sum).iterator(); } }); Assert.assertEquals("[3, 7]", partitionSums.collect().toString()); @@ -1367,8 +1367,8 @@ public void zipPartitions() { FlatMapFunction2, Iterator, Integer> sizesFn = new FlatMapFunction2, Iterator, Integer>() { @Override - public Iterable call(Iterator i, Iterator s) { - return Arrays.asList(Iterators.size(i), Iterators.size(s)); + public Iterator call(Iterator i, Iterator s) { + return Arrays.asList(Iterators.size(i), Iterators.size(s)).iterator(); } }; diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 93c34efb6662..7e681b67cf0c 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -165,8 +165,8 @@ space into words. // Split each line into words JavaDStream words = lines.flatMap( new FlatMapFunction() { - @Override public Iterable call(String x) { - return Arrays.asList(x.split(" ")); + @Override public Iterator call(String x) { + return Arrays.asList(x.split(" ")).iterator(); } }); {% endhighlight %} diff --git a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java index a5db8accdf13..635fb6a373c4 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java @@ -17,7 +17,10 @@ package org.apache.spark.examples; - +import java.util.ArrayList; +import java.util.List; +import java.util.Iterator; +import java.util.regex.Pattern; import scala.Tuple2; @@ -32,11 +35,6 @@ import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.api.java.function.PairFunction; -import java.util.ArrayList; -import java.util.List; -import java.util.Iterator; -import java.util.regex.Pattern; - /** * Computes the PageRank of URLs from an input file. 
Input file should * be in format of: @@ -108,13 +106,13 @@ public Double call(Iterable rs) { JavaPairRDD contribs = links.join(ranks).values() .flatMapToPair(new PairFlatMapFunction, Double>, String, Double>() { @Override - public Iterable> call(Tuple2, Double> s) { + public Iterator> call(Tuple2, Double> s) { int urlCount = Iterables.size(s._1); - List> results = new ArrayList>(); + List> results = new ArrayList<>(); for (String n : s._1) { - results.add(new Tuple2(n, s._2() / urlCount)); + results.add(new Tuple2<>(n, s._2() / urlCount)); } - return results; + return results.iterator(); } }); diff --git a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java index 9a6a944f7ede..d746a3d2b677 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java @@ -27,6 +27,7 @@ import org.apache.spark.api.java.function.PairFunction; import java.util.Arrays; +import java.util.Iterator; import java.util.List; import java.util.regex.Pattern; @@ -46,8 +47,8 @@ public static void main(String[] args) throws Exception { JavaRDD words = lines.flatMap(new FlatMapFunction() { @Override - public Iterable call(String s) { - return Arrays.asList(SPACE.split(s)); + public Iterator call(String s) { + return Arrays.asList(SPACE.split(s)).iterator(); } }); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaActorWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaActorWordCount.java index 62e563380a9e..cf774667f6c5 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaActorWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaActorWordCount.java @@ -18,6 +18,7 @@ package org.apache.spark.examples.streaming; import java.util.Arrays; +import java.util.Iterator; import scala.Tuple2; @@ -120,8 +121,8 @@ public static void main(String[] args) { // compute wordcount lines.flatMap(new FlatMapFunction() { @Override - public Iterable call(String s) { - return Arrays.asList(s.split("\\s+")); + public Iterator call(String s) { + return Arrays.asList(s.split("\\s+")).iterator(); } }).mapToPair(new PairFunction() { @Override diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java index 4b50fbf59f80..3d668adcf815 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java @@ -17,7 +17,6 @@ package org.apache.spark.examples.streaming; -import com.google.common.collect.Lists; import com.google.common.io.Closeables; import org.apache.spark.SparkConf; @@ -37,6 +36,8 @@ import java.io.InputStreamReader; import java.net.ConnectException; import java.net.Socket; +import java.util.Arrays; +import java.util.Iterator; import java.util.regex.Pattern; /** @@ -74,8 +75,8 @@ public static void main(String[] args) { new JavaCustomReceiver(args[0], Integer.parseInt(args[1]))); JavaDStream words = lines.flatMap(new FlatMapFunction() { @Override - public Iterable call(String x) { - return Lists.newArrayList(SPACE.split(x)); + public Iterator call(String x) { + return Arrays.asList(SPACE.split(x)).iterator(); } }); JavaPairDStream wordCounts = words.mapToPair( diff --git 
a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java index f9a5e7f69ffe..5107500a127c 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java @@ -20,11 +20,11 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Arrays; +import java.util.Iterator; import java.util.regex.Pattern; import scala.Tuple2; -import com.google.common.collect.Lists; import kafka.serializer.StringDecoder; import org.apache.spark.SparkConf; @@ -87,8 +87,8 @@ public String call(Tuple2 tuple2) { }); JavaDStream words = lines.flatMap(new FlatMapFunction() { @Override - public Iterable call(String x) { - return Lists.newArrayList(SPACE.split(x)); + public Iterator call(String x) { + return Arrays.asList(SPACE.split(x)).iterator(); } }); JavaPairDStream wordCounts = words.mapToPair( diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java index 337f8ffb5bfb..0df4cb40a9a7 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java @@ -17,20 +17,19 @@ package org.apache.spark.examples.streaming; +import java.util.Arrays; +import java.util.Iterator; import java.util.Map; import java.util.HashMap; import java.util.regex.Pattern; - import scala.Tuple2; -import com.google.common.collect.Lists; import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.examples.streaming.StreamingExamples; import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; @@ -88,8 +87,8 @@ public String call(Tuple2 tuple2) { JavaDStream words = lines.flatMap(new FlatMapFunction() { @Override - public Iterable call(String x) { - return Lists.newArrayList(SPACE.split(x)); + public Iterator call(String x) { + return Arrays.asList(SPACE.split(x)).iterator(); } }); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java index 3e9f0f4b8f12..b82b319acb73 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java @@ -17,8 +17,11 @@ package org.apache.spark.examples.streaming; +import java.util.Arrays; +import java.util.Iterator; +import java.util.regex.Pattern; + import scala.Tuple2; -import com.google.common.collect.Lists; import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.FlatMapFunction; @@ -31,8 +34,6 @@ import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; -import java.util.regex.Pattern; - /** * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. 
* @@ -67,8 +68,8 @@ public static void main(String[] args) { args[0], Integer.parseInt(args[1]), StorageLevels.MEMORY_AND_DISK_SER); JavaDStream words = lines.flatMap(new FlatMapFunction() { @Override - public Iterable call(String x) { - return Lists.newArrayList(SPACE.split(x)); + public Iterator call(String x) { + return Arrays.asList(SPACE.split(x)).iterator(); } }); JavaPairDStream wordCounts = words.mapToPair( diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java index bc963a02be60..bc8cbcdef727 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java @@ -21,11 +21,11 @@ import java.io.IOException; import java.nio.charset.Charset; import java.util.Arrays; +import java.util.Iterator; import java.util.List; import java.util.regex.Pattern; import scala.Tuple2; -import com.google.common.collect.Lists; import com.google.common.io.Files; import org.apache.spark.Accumulator; @@ -138,8 +138,8 @@ private static JavaStreamingContext createContext(String ip, JavaReceiverInputDStream lines = ssc.socketTextStream(ip, port); JavaDStream words = lines.flatMap(new FlatMapFunction() { @Override - public Iterable call(String x) { - return Lists.newArrayList(SPACE.split(x)); + public Iterator call(String x) { + return Arrays.asList(SPACE.split(x)).iterator(); } }); JavaPairDStream wordCounts = words.mapToPair( diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java index 084f68a8be43..f0228f5e6345 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java @@ -17,10 +17,10 @@ package org.apache.spark.examples.streaming; +import java.util.Arrays; +import java.util.Iterator; import java.util.regex.Pattern; -import com.google.common.collect.Lists; - import org.apache.spark.SparkConf; import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; @@ -72,8 +72,8 @@ public static void main(String[] args) { args[0], Integer.parseInt(args[1]), StorageLevels.MEMORY_AND_DISK_SER); JavaDStream words = lines.flatMap(new FlatMapFunction() { @Override - public Iterable call(String x) { - return Lists.newArrayList(SPACE.split(x)); + public Iterator call(String x) { + return Arrays.asList(SPACE.split(x)).iterator(); } }); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java index f52cc7c20576..6beab90f086d 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java @@ -18,6 +18,7 @@ package org.apache.spark.examples.streaming; import java.util.Arrays; +import java.util.Iterator; import java.util.List; import java.util.regex.Pattern; @@ -73,8 +74,8 @@ public static void main(String[] args) { JavaDStream words = lines.flatMap(new FlatMapFunction() { @Override - public Iterable call(String x) { - return Arrays.asList(SPACE.split(x)); + public Iterator 
call(String x) { + return Arrays.asList(SPACE.split(x)).iterator(); } }); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaTwitterHashTagJoinSentiments.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaTwitterHashTagJoinSentiments.java index d869768026ae..f0ae9a99bae4 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaTwitterHashTagJoinSentiments.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaTwitterHashTagJoinSentiments.java @@ -34,6 +34,7 @@ import twitter4j.Status; import java.util.Arrays; +import java.util.Iterator; import java.util.List; /** @@ -70,8 +71,8 @@ public static void main(String[] args) { JavaDStream words = stream.flatMap(new FlatMapFunction() { @Override - public Iterable call(Status s) { - return Arrays.asList(s.getText().split(" ")); + public Iterator call(Status s) { + return Arrays.asList(s.getText().split(" ")).iterator(); } }); diff --git a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java index 27d494ce355f..c0b58e713f64 100644 --- a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java +++ b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java @@ -294,7 +294,7 @@ public void zipPartitions() { sizeS += 1; s.next(); } - return Arrays.asList(sizeI, sizeS); + return Arrays.asList(sizeI, sizeS).iterator(); }; JavaRDD sizes = rdd1.zipPartitions(rdd2, sizesFn); Assert.assertEquals("[3, 2, 3, 2]", sizes.collect().toString()); diff --git a/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java index 06e0ff28afd9..64e044aa8e4a 100644 --- a/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java +++ b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java @@ -16,7 +16,10 @@ */ package org.apache.spark.examples.streaming; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; import java.util.List; import java.util.regex.Pattern; @@ -38,7 +41,6 @@ import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; import com.amazonaws.services.kinesis.AmazonKinesisClient; import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; -import com.google.common.collect.Lists; /** * Consumes messages from a Amazon Kinesis streams and does wordcount. 
@@ -154,8 +156,9 @@ public static void main(String[] args) { // Convert each line of Array[Byte] to String, and split into words JavaDStream words = unionStreams.flatMap(new FlatMapFunction() { @Override - public Iterable call(byte[] line) { - return Lists.newArrayList(WORD_SEPARATOR.split(new String(line))); + public Iterator call(byte[] line) { + String s = new String(line, StandardCharsets.UTF_8); + return Arrays.asList(WORD_SEPARATOR.split(s)).iterator(); } }); diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 501456b04317..643bee69694d 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -60,6 +60,37 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.SparkContext.emptyRDD"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.broadcast.HttpBroadcastFactory") ) ++ + Seq( + // SPARK-3369 Fix Iterable/Iterator in Java API + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.function.FlatMapFunction.call"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.function.FlatMapFunction.call"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.function.DoubleFlatMapFunction.call"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.function.DoubleFlatMapFunction.call"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.function.FlatMapFunction2.call"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.function.FlatMapFunction2.call"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.function.PairFlatMapFunction.call"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.function.PairFlatMapFunction.call"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.function.CoGroupFunction.call"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.function.CoGroupFunction.call"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.function.MapPartitionsFunction.call"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.function.MapPartitionsFunction.call"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.function.FlatMapGroupsFunction.call"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.function.FlatMapGroupsFunction.call") + ) ++ Seq( // SPARK-4819 replace Guava Optional ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.JavaSparkContext.getCheckpointDir"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index bd99c399571c..f182270a0872 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -346,7 +346,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = { - val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).iterator.asScala + val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).asScala mapPartitions(func)(encoder) } @@ -366,7 +366,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { - 
val func: (T) => Iterable[U] = x => f.call(x).asScala + val func: (T) => Iterator[U] = x => f.call(x).asScala flatMap(func)(encoder) } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 3c0f25a5dc53..a6fb62c17d59 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -111,24 +111,24 @@ public Integer call(String v) throws Exception { Dataset parMapped = ds.mapPartitions(new MapPartitionsFunction() { @Override - public Iterable call(Iterator it) throws Exception { - List ls = new LinkedList(); + public Iterator call(Iterator it) { + List ls = new LinkedList<>(); while (it.hasNext()) { - ls.add(it.next().toUpperCase()); + ls.add(it.next().toUpperCase(Locale.ENGLISH)); } - return ls; + return ls.iterator(); } }, Encoders.STRING()); Assert.assertEquals(Arrays.asList("HELLO", "WORLD"), parMapped.collectAsList()); Dataset flatMapped = ds.flatMap(new FlatMapFunction() { @Override - public Iterable call(String s) throws Exception { - List ls = new LinkedList(); + public Iterator call(String s) { + List ls = new LinkedList<>(); for (char c : s.toCharArray()) { ls.add(String.valueOf(c)); } - return ls; + return ls.iterator(); } }, Encoders.STRING()); Assert.assertEquals( @@ -192,12 +192,12 @@ public String call(Integer key, Iterator values) throws Exception { Dataset flatMapped = grouped.flatMapGroups( new FlatMapGroupsFunction() { @Override - public Iterable call(Integer key, Iterator values) throws Exception { + public Iterator call(Integer key, Iterator values) { StringBuilder sb = new StringBuilder(key.toString()); while (values.hasNext()) { sb.append(values.next()); } - return Collections.singletonList(sb.toString()); + return Collections.singletonList(sb.toString()).iterator(); } }, Encoders.STRING()); @@ -228,10 +228,7 @@ public Integer call(Integer v) throws Exception { grouped2, new CoGroupFunction() { @Override - public Iterable call( - Integer key, - Iterator left, - Iterator right) throws Exception { + public Iterator call(Integer key, Iterator left, Iterator right) { StringBuilder sb = new StringBuilder(key.toString()); while (left.hasNext()) { sb.append(left.next()); @@ -240,7 +237,7 @@ public Iterable call( while (right.hasNext()) { sb.append(right.next()); } - return Collections.singletonList(sb.toString()); + return Collections.singletonList(sb.toString()).iterator(); } }, Encoders.STRING()); diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index a791a474c673..f10de485d0f7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -166,8 +166,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * and then flattening the results */ def flatMap[U](f: FlatMapFunction[T, U]): JavaDStream[U] = { - import scala.collection.JavaConverters._ - def fn: (T) => Iterable[U] = (x: T) => f.call(x).asScala + def fn: (T) => Iterator[U] = (x: T) => f.call(x).asScala new JavaDStream(dstream.flatMap(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -176,8 +175,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * and then flattening the results */ def flatMapToPair[K2, 
V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairDStream[K2, V2] = { - import scala.collection.JavaConverters._ - def fn: (T) => Iterable[(K2, V2)] = (x: T) => f.call(x).asScala + def fn: (T) => Iterator[(K2, V2)] = (x: T) => f.call(x).asScala def cm: ClassTag[(K2, V2)] = fakeClassTag new JavaPairDStream(dstream.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -189,7 +187,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaDStream[U] = { def fn: (Iterator[T]) => Iterator[U] = { - (x: Iterator[T]) => f.call(x.asJava).iterator().asScala + (x: Iterator[T]) => f.call(x.asJava).asScala } new JavaDStream(dstream.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -202,7 +200,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]) : JavaPairDStream[K2, V2] = { def fn: (Iterator[T]) => Iterator[(K2, V2)] = { - (x: Iterator[T]) => f.call(x.asJava).iterator().asScala + (x: Iterator[T]) => f.call(x.asJava).asScala } new JavaPairDStream(dstream.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2]) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 1dfb4e7abc0e..db79eeab9c0c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -550,7 +550,7 @@ abstract class DStream[T: ClassTag] ( * Return a new DStream by applying a function to all elements of this DStream, * and then flattening the results */ - def flatMap[U: ClassTag](flatMapFunc: T => Traversable[U]): DStream[U] = ssc.withScope { + def flatMap[U: ClassTag](flatMapFunc: T => TraversableOnce[U]): DStream[U] = ssc.withScope { new FlatMappedDStream(this, context.sparkContext.clean(flatMapFunc)) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala index 96a444a7baa5..d60a6179782e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala @@ -25,7 +25,7 @@ import org.apache.spark.streaming.{Duration, Time} private[streaming] class FlatMappedDStream[T: ClassTag, U: ClassTag]( parent: DStream[T], - flatMapFunc: T => Traversable[U] + flatMapFunc: T => TraversableOnce[U] ) extends DStream[U](parent.ssc) { override def dependencies: List[DStream[_]] = List(parent) diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 4dbcef293487..806cea24cadd 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -271,12 +271,12 @@ public void testMapPartitions() { JavaDStream mapped = stream.mapPartitions( new FlatMapFunction, String>() { @Override - public Iterable call(Iterator in) { + public Iterator call(Iterator in) { StringBuilder out = new StringBuilder(); while (in.hasNext()) { out.append(in.next().toUpperCase(Locale.ENGLISH)); } - return Arrays.asList(out.toString()); + return Arrays.asList(out.toString()).iterator(); } }); 
JavaTestUtils.attachTestOutputStream(mapped); @@ -759,8 +759,8 @@ public void testFlatMap() { JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream flatMapped = stream.flatMap(new FlatMapFunction() { @Override - public Iterable call(String x) { - return Arrays.asList(x.split("(?!^)")); + public Iterator call(String x) { + return Arrays.asList(x.split("(?!^)")).iterator(); } }); JavaTestUtils.attachTestOutputStream(flatMapped); @@ -846,12 +846,12 @@ public void testPairFlatMap() { JavaPairDStream flatMapped = stream.flatMapToPair( new PairFlatMapFunction() { @Override - public Iterable> call(String in) { + public Iterator> call(String in) { List> out = new ArrayList<>(); for (String letter: in.split("(?!^)")) { out.add(new Tuple2<>(in.length(), letter)); } - return out; + return out.iterator(); } }); JavaTestUtils.attachTestOutputStream(flatMapped); @@ -1019,13 +1019,13 @@ public void testPairMapPartitions() { // Maps pair -> pair of different type JavaPairDStream reversed = pairStream.mapPartitionsToPair( new PairFlatMapFunction>, Integer, String>() { @Override - public Iterable> call(Iterator> in) { + public Iterator> call(Iterator> in) { List> out = new LinkedList<>(); while (in.hasNext()) { Tuple2 next = in.next(); out.add(next.swap()); } - return out; + return out.iterator(); } }); @@ -1089,12 +1089,12 @@ public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair JavaPairDStream flatMapped = pairStream.flatMapToPair( new PairFlatMapFunction, Integer, String>() { @Override - public Iterable> call(Tuple2 in) { + public Iterator> call(Tuple2 in) { List> out = new LinkedList<>(); for (Character s : in._1().toCharArray()) { out.add(new Tuple2<>(in._2(), s.toString())); } - return out; + return out.iterator(); } }); JavaTestUtils.attachTestOutputStream(flatMapped); From ae0309a8812a4fade3a0ea67d8986ca870aeb9eb Mon Sep 17 00:00:00 2001 From: zhuol Date: Tue, 26 Jan 2016 09:40:02 -0600 Subject: [PATCH 019/131] [SPARK-10911] Executors should System.exit on clean shutdown. Call system.exit explicitly to make sure non-daemon user threads terminate. Without this, user applications might live forever if the cluster manager does not appropriately kill them. E.g., YARN had this bug: HADOOP-12441. Author: zhuol Closes #9946 from zhuoliu/10911. --- .../org/apache/spark/executor/CoarseGrainedExecutorBackend.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index e3a6c4c07a75..136cf4a84d38 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -241,6 +241,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { } run(driverUrl, executorId, hostname, cores, appId, workerUrl, userClassPath) + System.exit(0) } private def printUsageAndExit() = { From 08c781ca672820be9ba32838bbe40d2643c4bde4 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Tue, 26 Jan 2016 07:50:37 -0800 Subject: [PATCH 020/131] [SPARK-12682][SQL] Add support for (optionally) not storing tables in hive metadata format This PR adds a new table option (`skip_hive_metadata`) that'd allow the user to skip storing the table metadata in hive metadata format. 
While this could be useful in general, the specific use-case for this change is that Hive doesn't handle wide schemas well (see https://issues.apache.org/jira/browse/SPARK-12682 and https://issues.apache.org/jira/browse/SPARK-6024) which in turn prevents such tables from being queried in SparkSQL. Author: Sameer Agarwal Closes #10826 from sameeragarwal/skip-hive-metadata. --- .../spark/sql/hive/HiveMetastoreCatalog.scala | 7 ++++ .../sql/hive/MetastoreDataSourcesSuite.scala | 32 +++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 0cfe03ba91ec..80e45d516280 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -327,7 +327,14 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // TODO: Support persisting partitioned data source relations in Hive compatible format val qualifiedTableName = tableIdent.quotedString + val skipHiveMetadata = options.getOrElse("skipHiveMetadata", "false").toBoolean val (hiveCompatibleTable, logMessage) = (maybeSerDe, dataSource.relation) match { + case _ if skipHiveMetadata => + val message = + s"Persisting partitioned data source relation $qualifiedTableName into " + + "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive." + (None, message) + case (Some(serde), relation: HadoopFsRelation) if relation.paths.length == 1 && relation.partitionColumns.isEmpty => val hiveTable = newHiveCompatibleMetastoreTable(relation, serde) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 211932fea00e..d9e4b020fdfc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -900,4 +900,36 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv sqlContext.sql("""use default""") sqlContext.sql("""drop database if exists testdb8156 CASCADE""") } + + test("skip hive metadata on table creation") { + val schema = StructType((1 to 5).map(i => StructField(s"c_$i", StringType))) + + catalog.createDataSourceTable( + tableIdent = TableIdentifier("not_skip_hive_metadata"), + userSpecifiedSchema = Some(schema), + partitionColumns = Array.empty[String], + bucketSpec = None, + provider = "parquet", + options = Map("path" -> "just a dummy path", "skipHiveMetadata" -> "false"), + isExternal = false) + + // As a proxy for verifying that the table was stored in Hive compatible format, we verify that + // each column of the table is of native type StringType. 
+ assert(catalog.client.getTable("default", "not_skip_hive_metadata").schema + .forall(column => HiveMetastoreTypes.toDataType(column.hiveType) == StringType)) + + catalog.createDataSourceTable( + tableIdent = TableIdentifier("skip_hive_metadata"), + userSpecifiedSchema = Some(schema), + partitionColumns = Array.empty[String], + bucketSpec = None, + provider = "parquet", + options = Map("path" -> "just a dummy path", "skipHiveMetadata" -> "true"), + isExternal = false) + + // As a proxy for verifying that the table was stored in SparkSQL format, we verify that + // the table has a column type as array of StringType. + assert(catalog.client.getTable("default", "skip_hive_metadata").schema + .forall(column => HiveMetastoreTypes.toDataType(column.hiveType) == ArrayType(StringType))) + } } From cbd507d69cea24adfb335d8fe26ab5a13c053ffc Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 26 Jan 2016 11:31:54 -0800 Subject: [PATCH 021/131] [SPARK-7799][STREAMING][DOCUMENT] Add the linking and deploying instructions for streaming-akka project Since `actorStream` is an external project, we should add the linking and deploying instructions for it. A follow up PR of #10744 Author: Shixiong Zhu Closes #10856 from zsxwing/akka-link-instruction. --- docs/streaming-custom-receivers.md | 81 ++++++++++++++++-------------- 1 file changed, 44 insertions(+), 37 deletions(-) diff --git a/docs/streaming-custom-receivers.md b/docs/streaming-custom-receivers.md index 95b99862ec06..84547748618d 100644 --- a/docs/streaming-custom-receivers.md +++ b/docs/streaming-custom-receivers.md @@ -257,54 +257,61 @@ The following table summarizes the characteristics of both types of receivers ## Implementing and Using a Custom Actor-based Receiver -

- Custom [Akka Actors](http://doc.akka.io/docs/akka/2.3.11/scala/actors.html) can also be used to -receive data. Extending [`ActorReceiver`](api/scala/index.html#org.apache.spark.streaming.akka.ActorReceiver) -allows received data to be stored in Spark using `store(...)` methods. The supervisor strategy of -this actor can be configured to handle failures, etc. +receive data. Here are the instructions. -{% highlight scala %} +1. **Linking:** You need to add the following dependency to your SBT or Maven project (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). -class CustomActor extends ActorReceiver { - def receive = { - case data: String => store(data) - } -} + groupId = org.apache.spark + artifactId = spark-streaming-akka_{{site.SCALA_BINARY_VERSION}} + version = {{site.SPARK_VERSION_SHORT}} -// A new input stream can be created with this custom actor as -val ssc: StreamingContext = ... -val lines = AkkaUtils.createStream[String](ssc, Props[CustomActor](), "CustomReceiver") +2. **Programming:** -{% endhighlight %} +
-See [ActorWordCount.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala) for an end-to-end example. -
+ You need to extend [`ActorReceiver`](api/scala/index.html#org.apache.spark.streaming.akka.ActorReceiver) + so as to store received data into Spark using `store(...)` methods. The supervisor strategy of + this actor can be configured to handle failures, etc. -Custom [Akka UntypedActors](http://doc.akka.io/docs/akka/2.3.11/java/untyped-actors.html) can also be used to -receive data. Extending [`JavaActorReceiver`](api/scala/index.html#org.apache.spark.streaming.akka.JavaActorReceiver) -allows received data to be stored in Spark using `store(...)` methods. The supervisor strategy of -this actor can be configured to handle failures, etc. + class CustomActor extends ActorReceiver { + def receive = { + case data: String => store(data) + } + } -{% highlight java %} + // A new input stream can be created with this custom actor as + val ssc: StreamingContext = ... + val lines = AkkaUtils.createStream[String](ssc, Props[CustomActor](), "CustomReceiver") -class CustomActor extends JavaActorReceiver { - @Override - public void onReceive(Object msg) throws Exception { - store((String) msg); - } -} + See [ActorWordCount.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala) for an end-to-end example. +
-// A new input stream can be created with this custom actor as -JavaStreamingContext jssc = ...; -JavaDStream lines = AkkaUtils.createStream(jssc, Props.create(CustomActor.class), "CustomReceiver"); + You need to extend [`JavaActorReceiver`](api/scala/index.html#org.apache.spark.streaming.akka.JavaActorReceiver) + so as to store received data into Spark using `store(...)` methods. The supervisor strategy of + this actor can be configured to handle failures, etc. -{% endhighlight %} + class CustomActor extends JavaActorReceiver { + @Override + public void onReceive(Object msg) throws Exception { + store((String) msg); + } + } -See [JavaActorWordCount.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/JavaActorWordCount.scala) for an end-to-end example. -
+ // A new input stream can be created with this custom actor as + JavaStreamingContext jssc = ...; + JavaDStream lines = AkkaUtils.createStream(jssc, Props.create(CustomActor.class), "CustomReceiver"); + + See [JavaActorWordCount.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/JavaActorWordCount.scala) for an end-to-end example. +
+ +3. **Deploying:** As with any Spark applications, `spark-submit` is used to launch your application. +You need to package `spark-streaming-akka_{{site.SCALA_BINARY_VERSION}}` and its dependencies into +the application JAR. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` +are marked as `provided` dependencies as those are already present in a Spark installation. Then +use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). Python API Since actors are available only in the Java and Scala libraries, AkkaUtils is not available in the Python API. From 8beab68152348c44cf2f89850f792f164b06470d Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 26 Jan 2016 11:56:46 -0800 Subject: [PATCH 022/131] [SPARK-11923][ML] Python API for ml.feature.ChiSqSelector https://issues.apache.org/jira/browse/SPARK-11923 Author: Xusen Yin Closes #10186 from yinxusen/SPARK-11923. --- python/pyspark/ml/feature.py | 98 +++++++++++++++++++++++++++++++++++- 1 file changed, 97 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index f139d81bc490..32f324685a9c 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -33,7 +33,8 @@ 'PolynomialExpansion', 'QuantileDiscretizer', 'RegexTokenizer', 'RFormula', 'RFormulaModel', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', 'StopWordsRemover', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', - 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel'] + 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel', + 'ChiSqSelector', 'ChiSqSelectorModel'] @inherit_doc @@ -2237,6 +2238,101 @@ class RFormulaModel(JavaModel): """ +@inherit_doc +class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol): + """ + .. note:: Experimental + + Chi-Squared feature selection, which selects categorical features to use for predicting a + categorical label. + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame( + ... [(Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0), + ... (Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0), + ... (Vectors.dense([1.0, 0.0, 15.0, 0.1]), 0.0)], + ... ["features", "label"]) + >>> selector = ChiSqSelector(numTopFeatures=1, outputCol="selectedFeatures") + >>> model = selector.fit(df) + >>> model.transform(df).head().selectedFeatures + DenseVector([1.0]) + >>> model.selectedFeatures + [3] + + .. versionadded:: 2.0.0 + """ + + # a placeholder to make it appear in the generated doc + numTopFeatures = \ + Param(Params._dummy(), "numTopFeatures", + "Number of features that selector will select, ordered by statistics value " + + "descending. If the number of features is < numTopFeatures, then this will select " + + "all features.") + + @keyword_only + def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label"): + """ + __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label") + """ + super(ChiSqSelector, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid) + self.numTopFeatures = \ + Param(self, "numTopFeatures", + "Number of features that selector will select, ordered by statistics value " + + "descending. 
If the number of features is < numTopFeatures, then this will " + + "select all features.") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.0.0") + def setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None, + labelCol="labels"): + """ + setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None,\ + labelCol="labels") + Sets params for this ChiSqSelector. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + @since("2.0.0") + def setNumTopFeatures(self, value): + """ + Sets the value of :py:attr:`numTopFeatures`. + """ + self._paramMap[self.numTopFeatures] = value + return self + + @since("2.0.0") + def getNumTopFeatures(self): + """ + Gets the value of numTopFeatures or its default value. + """ + return self.getOrDefault(self.numTopFeatures) + + def _create_model(self, java_model): + return ChiSqSelectorModel(java_model) + + +class ChiSqSelectorModel(JavaModel): + """ + .. note:: Experimental + + Model fitted by ChiSqSelector. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def selectedFeatures(self): + """ + List of indices to select (filter). Must be ordered asc. + """ + return self._call_java("selectedFeatures") + + if __name__ == "__main__": import doctest from pyspark.context import SparkContext From fbf7623d49525e3aa6b08f482afd7ee8118d80cb Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 26 Jan 2016 13:18:01 -0800 Subject: [PATCH 023/131] [SPARK-12952] EMLDAOptimizer initialize() should return EMLDAOptimizer other than its parent class https://issues.apache.org/jira/browse/SPARK-12952 Author: Xusen Yin Closes #10863 from yinxusen/SPARK-12952. --- .../org/apache/spark/mllib/clustering/LDAOptimizer.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index c19595e6cd21..7a41f7419153 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -95,7 +95,9 @@ final class EMLDAOptimizer extends LDAOptimizer { /** * Compute bipartite term/doc graph. */ - override private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer = { + override private[clustering] def initialize( + docs: RDD[(Long, Vector)], + lda: LDA): EMLDAOptimizer = { // EMLDAOptimizer currently only supports symmetric document-topic priors val docConcentration = lda.getDocConcentration From ee74498de372b16fe6350e3617e9e6ec87c6ae7b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 26 Jan 2016 14:20:11 -0800 Subject: [PATCH 024/131] [SPARK-8725][PROJECT-INFRA] Test modules in topologically-sorted order in dev/run-tests This patch improves our `dev/run-tests` script to test modules in a topologically-sorted order based on modules' dependencies. This will help to ensure that bugs in upstream projects are not misattributed to downstream projects because those projects' tests were the first ones to exhibit the failure Topological sorting is also useful for shortening the feedback loop when testing pull requests: if I make a change in SQL then the SQL tests should run before MLlib, not after. In addition, this patch also updates our test module definitions to split `sql` into `catalyst`, `sql`, and `hive` in order to allow more tests to be skipped when changing only `hive/` files. 
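As a rough illustration of the ordering this gives (a sketch, not part of the patch: it assumes it is run from the `dev/` directory so that `sparktestsupport` is importable, and the dependency map is a simplified stand-in for the real definitions in `modules.py`):

    # Minimal sketch of the topological ordering used by dev/run-tests.py.
    # Keys are modules, values are the modules they depend on (simplified).
    from sparktestsupport.toposort import toposort_flatten

    deps = {
        "catalyst": set(),
        "sql": {"catalyst"},
        "hive": {"sql"},
        "hive-thriftserver": {"hive"},
    }

    # Upstream modules sort first, so a catalyst bug fails catalyst's own tests
    # before hive's: ['catalyst', 'sql', 'hive', 'hive-thriftserver']
    print(toposort_flatten(deps, sort=True))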
Author: Josh Rosen Closes #10885 from JoshRosen/SPARK-8725. --- NOTICE | 16 ++++++ dev/run-tests.py | 25 ++++++---- dev/sparktestsupport/modules.py | 54 +++++++++++++++++--- dev/sparktestsupport/toposort.py | 85 ++++++++++++++++++++++++++++++++ 4 files changed, 162 insertions(+), 18 deletions(-) create mode 100644 dev/sparktestsupport/toposort.py diff --git a/NOTICE b/NOTICE index e416aadce991..6a26155fb495 100644 --- a/NOTICE +++ b/NOTICE @@ -650,3 +650,19 @@ For CSV functionality: */ +=============================================================================== +For dev/sparktestsupport/toposort.py: + +Copyright 2014 True Blade Systems, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/dev/run-tests.py b/dev/run-tests.py index 8f47728f206c..c78a66f6aa54 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -29,6 +29,7 @@ from sparktestsupport import SPARK_HOME, USER_HOME, ERROR_CODES from sparktestsupport.shellutils import exit_from_command_with_retcode, run_cmd, rm_r, which +from sparktestsupport.toposort import toposort_flatten, toposort import sparktestsupport.modules as modules @@ -43,7 +44,7 @@ def determine_modules_for_files(filenames): If a file is not associated with a more specific submodule, then this method will consider that file to belong to the 'root' module. - >>> sorted(x.name for x in determine_modules_for_files(["python/pyspark/a.py", "sql/test/foo"])) + >>> sorted(x.name for x in determine_modules_for_files(["python/pyspark/a.py", "sql/core/foo"])) ['pyspark-core', 'sql'] >>> [x.name for x in determine_modules_for_files(["file_not_matched_by_any_subproject"])] ['root'] @@ -99,14 +100,16 @@ def determine_modules_to_test(changed_modules): Given a set of modules that have changed, compute the transitive closure of those modules' dependent modules in order to determine the set of modules that should be tested. - >>> sorted(x.name for x in determine_modules_to_test([modules.root])) + Returns a topologically-sorted list of modules (ties are broken by sorting on module names). + + >>> [x.name for x in determine_modules_to_test([modules.root])] ['root'] - >>> sorted(x.name for x in determine_modules_to_test([modules.graphx])) - ['examples', 'graphx'] - >>> x = sorted(x.name for x in determine_modules_to_test([modules.sql])) + >>> [x.name for x in determine_modules_to_test([modules.graphx])] + ['graphx', 'examples'] + >>> x = [x.name for x in determine_modules_to_test([modules.sql])] >>> x # doctest: +NORMALIZE_WHITESPACE - ['examples', 'hive-thriftserver', 'mllib', 'pyspark-ml', \ - 'pyspark-mllib', 'pyspark-sql', 'sparkr', 'sql'] + ['sql', 'hive', 'mllib', 'examples', 'hive-thriftserver', 'pyspark-sql', 'sparkr', + 'pyspark-mllib', 'pyspark-ml'] """ # If we're going to have to run all of the tests, then we can just short-circuit # and return 'root'. 
No module depends on root, so if it appears then it will be @@ -116,7 +119,9 @@ def determine_modules_to_test(changed_modules): modules_to_test = set() for module in changed_modules: modules_to_test = modules_to_test.union(determine_modules_to_test(module.dependent_modules)) - return modules_to_test.union(set(changed_modules)) + modules_to_test = modules_to_test.union(set(changed_modules)) + return toposort_flatten( + {m: set(m.dependencies).intersection(modules_to_test) for m in modules_to_test}, sort=True) def determine_tags_to_exclude(changed_modules): @@ -377,12 +382,12 @@ def run_scala_tests_maven(test_profiles): def run_scala_tests_sbt(test_modules, test_profiles): - sbt_test_goals = set(itertools.chain.from_iterable(m.sbt_test_goals for m in test_modules)) + sbt_test_goals = list(itertools.chain.from_iterable(m.sbt_test_goals for m in test_modules)) if not sbt_test_goals: return - profiles_and_goals = test_profiles + list(sbt_test_goals) + profiles_and_goals = test_profiles + sbt_test_goals print("[info] Running Spark tests using SBT with these arguments: ", " ".join(profiles_and_goals)) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 032c0616edb1..07c3078e4549 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -15,12 +15,14 @@ # limitations under the License. # +from functools import total_ordering import itertools import re all_modules = [] +@total_ordering class Module(object): """ A module is the basic abstraction in our test runner script. Each module consists of a set of @@ -75,20 +77,56 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= def contains_file(self, filename): return any(re.match(p, filename) for p in self.source_file_prefixes) + def __repr__(self): + return "Module<%s>" % self.name + + def __lt__(self, other): + return self.name < other.name + + def __eq__(self, other): + return self.name == other.name + + def __ne__(self, other): + return not (self.name == other.name) + + def __hash__(self): + return hash(self.name) + + +catalyst = Module( + name="catalyst", + dependencies=[], + source_file_regexes=[ + "sql/catalyst/", + ], + sbt_test_goals=[ + "catalyst/test", + ], +) + sql = Module( name="sql", - dependencies=[], + dependencies=[catalyst], source_file_regexes=[ - "sql/(?!hive-thriftserver)", + "sql/core/", + ], + sbt_test_goals=[ + "sql/test", + ], +) + +hive = Module( + name="hive", + dependencies=[sql], + source_file_regexes=[ + "sql/hive/", "bin/spark-sql", ], build_profile_flags=[ "-Phive", ], sbt_test_goals=[ - "catalyst/test", - "sql/test", "hive/test", ], test_tags=[ @@ -99,7 +137,7 @@ def contains_file(self, filename): hive_thriftserver = Module( name="hive-thriftserver", - dependencies=[sql], + dependencies=[hive], source_file_regexes=[ "sql/hive-thriftserver", "sbin/start-thriftserver.sh", @@ -282,7 +320,7 @@ def contains_file(self, filename): examples = Module( name="examples", - dependencies=[graphx, mllib, streaming, sql], + dependencies=[graphx, mllib, streaming, hive], source_file_regexes=[ "examples/", ], @@ -314,7 +352,7 @@ def contains_file(self, filename): pyspark_sql = Module( name="pyspark-sql", - dependencies=[pyspark_core, sql], + dependencies=[pyspark_core, hive], source_file_regexes=[ "python/pyspark/sql" ], @@ -404,7 +442,7 @@ def contains_file(self, filename): sparkr = Module( name="sparkr", - dependencies=[sql, mllib], + dependencies=[hive, mllib], source_file_regexes=[ "R/", ], diff --git a/dev/sparktestsupport/toposort.py 
b/dev/sparktestsupport/toposort.py new file mode 100644 index 000000000000..6c67b4504bc3 --- /dev/null +++ b/dev/sparktestsupport/toposort.py @@ -0,0 +1,85 @@ +####################################################################### +# Implements a topological sort algorithm. +# +# Copyright 2014 True Blade Systems, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Notes: +# Based on http://code.activestate.com/recipes/578272-topological-sort +# with these major changes: +# Added unittests. +# Deleted doctests (maybe not the best idea in the world, but it cleans +# up the docstring). +# Moved functools import to the top of the file. +# Changed assert to a ValueError. +# Changed iter[items|keys] to [items|keys], for python 3 +# compatibility. I don't think it matters for python 2 these are +# now lists instead of iterables. +# Copy the input so as to leave it unmodified. +# Renamed function from toposort2 to toposort. +# Handle empty input. +# Switch tests to use set literals. +# +######################################################################## + +from functools import reduce as _reduce + + +__all__ = ['toposort', 'toposort_flatten'] + + +def toposort(data): + """Dependencies are expressed as a dictionary whose keys are items +and whose values are a set of dependent items. Output is a list of +sets in topological order. The first set consists of items with no +dependences, each subsequent set consists of items that depend upon +items in the preceeding sets. +""" + + # Special case empty input. + if len(data) == 0: + return + + # Copy the input so as to leave it unmodified. + data = data.copy() + + # Ignore self dependencies. + for k, v in data.items(): + v.discard(k) + # Find all items that don't depend on anything. + extra_items_in_deps = _reduce(set.union, data.values()) - set(data.keys()) + # Add empty dependences where needed. + data.update({item: set() for item in extra_items_in_deps}) + while True: + ordered = set(item for item, dep in data.items() if len(dep) == 0) + if not ordered: + break + yield ordered + data = {item: (dep - ordered) + for item, dep in data.items() + if item not in ordered} + if len(data) != 0: + raise ValueError('Cyclic dependencies exist among these items: {}'.format( + ', '.join(repr(x) for x in data.items()))) + + +def toposort_flatten(data, sort=True): + """Returns a single list of dependencies. For any set returned by +toposort(), those items are sorted and appended to the result (just to +make the results deterministic).""" + + result = [] + for d in toposort(data): + result.extend((sorted if sort else list)(d)) + return result From 83507fea9f45c336d73dd4795b8cb37bcd63e31d Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 26 Jan 2016 14:29:29 -0800 Subject: [PATCH 025/131] [SQL] Minor Scaladoc format fix Otherwise the `^` character is always marked as error in IntelliJ since it represents an unclosed superscript markup tag. Author: Cheng Lian Closes #10926 from liancheng/agg-doc-fix. 
--- .../sql/catalyst/expressions/aggregate/interfaces.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index ddd99c51ab0c..561fa3321d8f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -200,7 +200,7 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac * For example, we have two aggregate functions `avg(x)` and `avg(y)`, which share the same * aggregation buffer. In this shared buffer, the position of the first buffer value of `avg(x)` * will be 0 and the position of the first buffer value of `avg(y)` will be 2: - * + * {{{ * avg(x) mutableAggBufferOffset = 0 * | * v @@ -210,7 +210,7 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac * ^ * | * avg(y) mutableAggBufferOffset = 2 - * + * }}} */ protected val mutableAggBufferOffset: Int @@ -233,7 +233,7 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac * `avg(x)` and `avg(y)`. In the shared input aggregation buffer, the position of the first * buffer value of `avg(x)` will be 1 and the position of the first buffer value of `avg(y)` * will be 3 (position 0 is used for the value of `key`): - * + * {{{ * avg(x) inputAggBufferOffset = 1 * | * v @@ -243,7 +243,7 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac * ^ * | * avg(y) inputAggBufferOffset = 3 - * + * }}} */ protected val inputAggBufferOffset: Int From 19fdb21afbf0eae4483cf6d4ef32daffd1994b89 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Tue, 26 Jan 2016 14:58:39 -0800 Subject: [PATCH 026/131] [SPARK-12993][PYSPARK] Remove usage of ADD_FILES in pyspark environment variable ADD_FILES is created for adding python files on spark context to be distributed to executors (SPARK-865), this is deprecated now. User are encouraged to use --py-files for adding python files. Author: Jeff Zhang Closes #10913 from zjffdu/SPARK-12993. --- python/pyspark/shell.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 26cafca8b838..7c37f7519347 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -32,15 +32,10 @@ from pyspark.sql import SQLContext, HiveContext from pyspark.storagelevel import StorageLevel -# this is the deprecated equivalent of ADD_JARS -add_files = None -if os.environ.get("ADD_FILES") is not None: - add_files = os.environ.get("ADD_FILES").split(',') - if os.environ.get("SPARK_EXECUTOR_URI"): SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"]) -sc = SparkContext(pyFiles=add_files) +sc = SparkContext() atexit.register(lambda: sc.stop()) try: @@ -68,10 +63,6 @@ platform.python_build()[1])) print("SparkContext available as sc, %s available as sqlContext." 
% sqlContext.__class__.__name__) -if add_files is not None: - print("Warning: ADD_FILES environment variable is deprecated, use --py-files argument instead") - print("Adding files: [%s]" % ", ".join(add_files)) - # The ./bin/pyspark script stores the old PYTHONSTARTUP value in OLD_PYTHONSTARTUP, # which allows us to execute the user's PYTHONSTARTUP file: _pythonstartup = os.environ.get('OLD_PYTHONSTARTUP') From eb917291ca1a2d68ca0639cb4b1464a546603eba Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 26 Jan 2016 15:53:48 -0800 Subject: [PATCH 027/131] [SPARK-10509][PYSPARK] Reduce excessive param boiler plate code The current python ml params require cut-and-pasting the param setup and description between the class & ```__init__``` methods. Remove this possible case of errors & simplify use of custom params by adding a ```_copy_new_parent``` method to param so as to avoid cut and pasting (and cut and pasting at different indentation levels urgh). Author: Holden Karau Closes #10216 from holdenk/SPARK-10509-excessive-param-boiler-plate-code. --- python/pyspark/ml/classification.py | 32 ------ python/pyspark/ml/clustering.py | 7 -- python/pyspark/ml/evaluation.py | 12 --- python/pyspark/ml/feature.py | 98 +------------------ python/pyspark/ml/param/__init__.py | 22 +++++ .../ml/param/_shared_params_code_gen.py | 17 +--- python/pyspark/ml/param/shared.py | 81 +-------------- python/pyspark/ml/pipeline.py | 4 +- python/pyspark/ml/recommendation.py | 11 --- python/pyspark/ml/regression.py | 46 --------- python/pyspark/ml/tests.py | 12 +++ python/pyspark/ml/tuning.py | 18 ---- 12 files changed, 43 insertions(+), 317 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 265c6a14f1ca..3179fb30ab4d 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -72,7 +72,6 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti .. versionadded:: 1.3.0 """ - # a placeholder to make it appear in the generated doc threshold = Param(Params._dummy(), "threshold", "Threshold in binary classification prediction, in range [0, 1]." + " If threshold and thresholds are both set, they must match.") @@ -92,10 +91,6 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred super(LogisticRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.LogisticRegression", self.uid) - #: param for threshold in binary classification, in range [0, 1]. - self.threshold = Param(self, "threshold", - "Threshold in binary classification prediction, in range [0, 1]." + - " If threshold and thresholds are both set, they must match.") self._setDefault(maxIter=100, regParam=0.1, tol=1E-6, threshold=0.5) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -232,7 +227,6 @@ class TreeClassifierParams(object): """ supportedImpurities = ["entropy", "gini"] - # a placeholder to make it appear in the generated doc impurity = Param(Params._dummy(), "impurity", "Criterion used for information gain calculation (case-insensitive). " + "Supported options: " + @@ -240,10 +234,6 @@ class TreeClassifierParams(object): def __init__(self): super(TreeClassifierParams, self).__init__() - #: param for Criterion used for information gain calculation (case-insensitive). - self.impurity = Param(self, "impurity", "Criterion used for information " + - "gain calculation (case-insensitive). 
Supported options: " + - ", ".join(self.supportedImpurities)) @since("1.6.0") def setImpurity(self, value): @@ -485,7 +475,6 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc lossType = Param(Params._dummy(), "lossType", "Loss function which GBT tries to minimize (case-insensitive). " + "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) @@ -504,10 +493,6 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred super(GBTClassifier, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.GBTClassifier", self.uid) - #: param for Loss function which GBT tries to minimize (case-insensitive). - self.lossType = Param(self, "lossType", - "Loss function which GBT tries to minimize (case-insensitive). " + - "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic", maxIter=20, stepSize=0.1) @@ -597,7 +582,6 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H .. versionadded:: 1.5.0 """ - # a placeholder to make it appear in the generated doc smoothing = Param(Params._dummy(), "smoothing", "The smoothing parameter, should be >= 0, " + "default is 1.0") modelType = Param(Params._dummy(), "modelType", "The model type which is a string " + @@ -615,13 +599,6 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred super(NaiveBayes, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.NaiveBayes", self.uid) - #: param for the smoothing parameter. - self.smoothing = Param(self, "smoothing", "The smoothing parameter, should be >= 0, " + - "default is 1.0") - #: param for the model type. - self.modelType = Param(self, "modelType", "The model type which is a string " + - "(case-sensitive). Supported options: multinomial (default) " + - "and bernoulli.") self._setDefault(smoothing=1.0, modelType="multinomial") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -734,7 +711,6 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, .. versionadded:: 1.6.0 """ - # a placeholder to make it appear in the generated doc layers = Param(Params._dummy(), "layers", "Sizes of layers from input layer to output layer " + "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 " + "neurons and output layer of 10 neurons, default is [1, 1].") @@ -753,14 +729,6 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred super(MultilayerPerceptronClassifier, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid) - self.layers = Param(self, "layers", "Sizes of layers from input layer to output layer " + - "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with " + - "100 neurons and output layer of 10 neurons, default is [1, 1].") - self.blockSize = Param(self, "blockSize", "Block size for stacking input data in " + - "matrices. Data is stacked within partitions. If block size is " + - "more than remaining data in a partition then it is adjusted to " + - "the size of this data. 
Recommended size is between 10 and 1000, " + - "default is 128.") self._setDefault(maxIter=100, tol=1E-4, layers=[1, 1], blockSize=128) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 9189c0222022..60d1c9aaec98 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -73,7 +73,6 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol .. versionadded:: 1.5.0 """ - # a placeholder to make it appear in the generated doc k = Param(Params._dummy(), "k", "number of clusters to create") initMode = Param(Params._dummy(), "initMode", "the initialization algorithm. This can be either \"random\" to " + @@ -90,12 +89,6 @@ def __init__(self, featuresCol="features", predictionCol="prediction", k=2, """ super(KMeans, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.KMeans", self.uid) - self.k = Param(self, "k", "number of clusters to create") - self.initMode = Param(self, "initMode", - "the initialization algorithm. This can be either \"random\" to " + - "choose random points as initial cluster centers, or \"k-means||\" " + - "to use a parallel variant of k-means++") - self.initSteps = Param(self, "initSteps", "steps for k-means initialization mode") self._setDefault(k=2, initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 6ff68abd8f18..c9b95b3bf45d 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -124,7 +124,6 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc metricName = Param(Params._dummy(), "metricName", "metric name in evaluation (areaUnderROC|areaUnderPR)") @@ -138,9 +137,6 @@ def __init__(self, rawPredictionCol="rawPrediction", labelCol="label", super(BinaryClassificationEvaluator, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator", self.uid) - #: param for metric name in evaluation (areaUnderROC|areaUnderPR) - self.metricName = Param(self, "metricName", - "metric name in evaluation (areaUnderROC|areaUnderPR)") self._setDefault(rawPredictionCol="rawPrediction", labelCol="label", metricName="areaUnderROC") kwargs = self.__init__._input_kwargs @@ -210,9 +206,6 @@ def __init__(self, predictionCol="prediction", labelCol="label", super(RegressionEvaluator, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.evaluation.RegressionEvaluator", self.uid) - #: param for metric name in evaluation (mse|rmse|r2|mae) - self.metricName = Param(self, "metricName", - "metric name in evaluation (mse|rmse|r2|mae)") self._setDefault(predictionCol="prediction", labelCol="label", metricName="rmse") kwargs = self.__init__._input_kwargs @@ -265,7 +258,6 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio .. 
versionadded:: 1.5.0 """ - # a placeholder to make it appear in the generated doc metricName = Param(Params._dummy(), "metricName", "metric name in evaluation " "(f1|precision|recall|weightedPrecision|weightedRecall)") @@ -280,10 +272,6 @@ def __init__(self, predictionCol="prediction", labelCol="label", super(MulticlassClassificationEvaluator, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator", self.uid) - # param for metric name in evaluation (f1|precision|recall|weightedPrecision|weightedRecall) - self.metricName = Param(self, "metricName", - "metric name in evaluation" - " (f1|precision|recall|weightedPrecision|weightedRecall)") self._setDefault(predictionCol="prediction", labelCol="label", metricName="f1") kwargs = self.__init__._input_kwargs diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 32f324685a9c..22081233b04d 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -57,7 +57,6 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol): .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc threshold = Param(Params._dummy(), "threshold", "threshold in binary classification prediction, in range [0, 1]") @@ -68,8 +67,6 @@ def __init__(self, threshold=0.0, inputCol=None, outputCol=None): """ super(Binarizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Binarizer", self.uid) - self.threshold = Param(self, "threshold", - "threshold in binary classification prediction, in range [0, 1]") self._setDefault(threshold=0.0) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -125,7 +122,6 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol): .. versionadded:: 1.3.0 """ - # a placeholder to make it appear in the generated doc splits = \ Param(Params._dummy(), "splits", "Split points for mapping continuous features into buckets. With n+1 splits, " + @@ -142,19 +138,6 @@ def __init__(self, splits=None, inputCol=None, outputCol=None): """ super(Bucketizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid) - #: param for Splitting points for mapping continuous features into buckets. With n+1 splits, - # there are n buckets. A bucket defined by splits x,y holds values in the range [x,y) - # except the last bucket, which also includes y. The splits should be strictly increasing. - # Values at -inf, inf must be explicitly provided to cover all Double values; otherwise, - # values outside the splits specified will be treated as errors. - self.splits = \ - Param(self, "splits", - "Split points for mapping continuous features into buckets. With n+1 splits, " + - "there are n buckets. A bucket defined by splits x,y holds values in the " + - "range [x,y) except the last bucket, which also includes y. The splits " + - "should be strictly increasing. Values at -inf, inf must be explicitly " + - "provided to cover all Double values; otherwise, values outside the splits " + - "specified will be treated as errors.") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -210,7 +193,6 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol): .. versionadded:: 1.6.0 """ - # a placeholder to make it appear in the generated doc minTF = Param( Params._dummy(), "minTF", "Filter to ignore rare words in" + " a document. 
For each document, terms with frequency/count less than the given" + @@ -235,22 +217,6 @@ def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outpu super(CountVectorizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", self.uid) - self.minTF = Param( - self, "minTF", "Filter to ignore rare words in" + - " a document. For each document, terms with frequency/count less than the given" + - " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + - " times the term must appear in the document); if this is a double in [0,1), then " + - "this specifies a fraction (out of the document's token count). Note that the " + - "parameter is only used in transform of CountVectorizerModel and does not affect" + - "fitting. Default 1.0") - self.minDF = Param( - self, "minDF", "Specifies the minimum number of" + - " different documents a term must appear in to be included in the vocabulary." + - " If this is an integer >= 1, this specifies the number of documents the term must" + - " appear in; if this is a double in [0,1), then this specifies the fraction of " + - "documents. Default 1.0") - self.vocabSize = Param( - self, "vocabSize", "max size of the vocabulary. Default 1 << 18.") self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -359,7 +325,6 @@ class DCT(JavaTransformer, HasInputCol, HasOutputCol): .. versionadded:: 1.6.0 """ - # a placeholder to make it appear in the generated doc inverse = Param(Params._dummy(), "inverse", "Set transformer to perform inverse DCT, " + "default False.") @@ -370,8 +335,6 @@ def __init__(self, inverse=False, inputCol=None, outputCol=None): """ super(DCT, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.DCT", self.uid) - self.inverse = Param(self, "inverse", "Set transformer to perform inverse DCT, " + - "default False.") self._setDefault(inverse=False) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -423,7 +386,6 @@ class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol): .. versionadded:: 1.5.0 """ - # a placeholder to make it appear in the generated doc scalingVec = Param(Params._dummy(), "scalingVec", "vector for hadamard product, " + "it must be MLlib Vector type.") @@ -435,8 +397,6 @@ def __init__(self, scalingVec=None, inputCol=None, outputCol=None): super(ElementwiseProduct, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ElementwiseProduct", self.uid) - self.scalingVec = Param(self, "scalingVec", "vector for hadamard product, " + - "it must be MLlib Vector type.") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -531,7 +491,6 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol): .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc minDocFreq = Param(Params._dummy(), "minDocFreq", "minimum of documents in which a term should appear for filtering") @@ -542,8 +501,6 @@ def __init__(self, minDocFreq=0, inputCol=None, outputCol=None): """ super(IDF, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IDF", self.uid) - self.minDocFreq = Param(self, "minDocFreq", - "minimum of documents in which a term should appear for filtering") self._setDefault(minDocFreq=0) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -623,7 +580,6 @@ class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol): .. 
versionadded:: 1.6.0 """ - # a placeholder to make it appear in the generated doc min = Param(Params._dummy(), "min", "Lower bound of the output feature range") max = Param(Params._dummy(), "max", "Upper bound of the output feature range") @@ -634,8 +590,6 @@ def __init__(self, min=0.0, max=1.0, inputCol=None, outputCol=None): """ super(MinMaxScaler, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MinMaxScaler", self.uid) - self.min = Param(self, "min", "Lower bound of the output feature range") - self.max = Param(self, "max", "Upper bound of the output feature range") self._setDefault(min=0.0, max=1.0) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -745,7 +699,6 @@ class NGram(JavaTransformer, HasInputCol, HasOutputCol): .. versionadded:: 1.5.0 """ - # a placeholder to make it appear in the generated doc n = Param(Params._dummy(), "n", "number of elements per n-gram (>=1)") @keyword_only @@ -755,7 +708,6 @@ def __init__(self, n=2, inputCol=None, outputCol=None): """ super(NGram, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.NGram", self.uid) - self.n = Param(self, "n", "number of elements per n-gram (>=1)") self._setDefault(n=2) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -808,7 +760,6 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol): .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc p = Param(Params._dummy(), "p", "the p norm value.") @keyword_only @@ -818,7 +769,6 @@ def __init__(self, p=2.0, inputCol=None, outputCol=None): """ super(Normalizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Normalizer", self.uid) - self.p = Param(self, "p", "the p norm value.") self._setDefault(p=2.0) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -887,7 +837,6 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol): .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc dropLast = Param(Params._dummy(), "dropLast", "whether to drop the last category") @keyword_only @@ -897,7 +846,6 @@ def __init__(self, dropLast=True, inputCol=None, outputCol=None): """ super(OneHotEncoder, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.OneHotEncoder", self.uid) - self.dropLast = Param(self, "dropLast", "whether to drop the last category") self._setDefault(dropLast=True) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -950,7 +898,6 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol): .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc degree = Param(Params._dummy(), "degree", "the polynomial degree to expand (>= 1)") @keyword_only @@ -961,7 +908,6 @@ def __init__(self, degree=2, inputCol=None, outputCol=None): super(PolynomialExpansion, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.feature.PolynomialExpansion", self.uid) - self.degree = Param(self, "degree", "the polynomial degree to expand (>= 1)") self._setDefault(degree=2) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -1107,7 +1053,6 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol): .. 
versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc minTokenLength = Param(Params._dummy(), "minTokenLength", "minimum token length (>= 0)") gaps = Param(Params._dummy(), "gaps", "whether regex splits on gaps (True) or matches tokens") pattern = Param(Params._dummy(), "pattern", "regex pattern (Java dialect) used for tokenizing") @@ -1123,11 +1068,6 @@ def __init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, """ super(RegexTokenizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RegexTokenizer", self.uid) - self.minTokenLength = Param(self, "minTokenLength", "minimum token length (>= 0)") - self.gaps = Param(self, "gaps", "whether regex splits on gaps (True) or matches tokens") - self.pattern = Param(self, "pattern", "regex pattern (Java dialect) used for tokenizing") - self.toLowercase = Param(self, "toLowercase", "whether to convert all characters to " + - "lowercase before tokenizing") self._setDefault(minTokenLength=1, gaps=True, pattern="\\s+", toLowercase=True) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -1223,7 +1163,6 @@ class SQLTransformer(JavaTransformer): .. versionadded:: 1.6.0 """ - # a placeholder to make it appear in the generated doc statement = Param(Params._dummy(), "statement", "SQL statement") @keyword_only @@ -1233,7 +1172,6 @@ def __init__(self, statement=None): """ super(SQLTransformer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.SQLTransformer", self.uid) - self.statement = Param(self, "statement", "SQL statement") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -1285,7 +1223,6 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol): .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc withMean = Param(Params._dummy(), "withMean", "Center data with mean") withStd = Param(Params._dummy(), "withStd", "Scale to unit standard deviation") @@ -1296,8 +1233,6 @@ def __init__(self, withMean=False, withStd=True, inputCol=None, outputCol=None): """ super(StandardScaler, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StandardScaler", self.uid) - self.withMean = Param(self, "withMean", "Center data with mean") - self.withStd = Param(self, "withStd", "Scale to unit standard deviation") self._setDefault(withMean=False, withStd=True) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -1453,7 +1388,6 @@ class IndexToString(JavaTransformer, HasInputCol, HasOutputCol): .. versionadded:: 1.6.0 """ - # a placeholder to make the labels show up in generated doc labels = Param(Params._dummy(), "labels", "Optional array of labels specifying index-string mapping." + " If not provided or if empty, then metadata from inputCol is used instead.") @@ -1466,9 +1400,6 @@ def __init__(self, inputCol=None, outputCol=None, labels=None): super(IndexToString, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IndexToString", self.uid) - self.labels = Param(self, "labels", - "Optional array of labels specifying index-string mapping. If not" + - " provided or if empty, then metadata from inputCol is used instead.") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -1507,7 +1438,7 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol): .. 
versionadded:: 1.6.0 """ - # a placeholder to make the stopwords show up in generated doc + stopWords = Param(Params._dummy(), "stopWords", "The words to be filtered out") caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " + "comparison over the stop words") @@ -1522,9 +1453,6 @@ def __init__(self, inputCol=None, outputCol=None, stopWords=None, super(StopWordsRemover, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover", self.uid) - self.stopWords = Param(self, "stopWords", "The words to be filtered out") - self.caseSensitive = Param(self, "caseSensitive", "whether to do a case " + - "sensitive comparison over the stop words") stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWords defaultStopWords = stopWordsObj.English() self._setDefault(stopWords=defaultStopWords) @@ -1727,7 +1655,6 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol): .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc maxCategories = Param(Params._dummy(), "maxCategories", "Threshold for the number of values a categorical feature can take " + "(>= 2). If a feature is found to have > maxCategories values, then " + @@ -1740,10 +1667,6 @@ def __init__(self, maxCategories=20, inputCol=None, outputCol=None): """ super(VectorIndexer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorIndexer", self.uid) - self.maxCategories = Param(self, "maxCategories", - "Threshold for the number of values a categorical feature " + - "can take (>= 2). If a feature is found to have " + - "> maxCategories values, then it is declared continuous.") self._setDefault(maxCategories=20) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -1832,7 +1755,6 @@ class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol): .. versionadded:: 1.6.0 """ - # a placeholder to make it appear in the generated doc indices = Param(Params._dummy(), "indices", "An array of indices to select features from " + "a vector column. There can be no overlap with names.") names = Param(Params._dummy(), "names", "An array of feature names to select features from " + @@ -1847,12 +1769,6 @@ def __init__(self, inputCol=None, outputCol=None, indices=None, names=None): """ super(VectorSlicer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorSlicer", self.uid) - self.indices = Param(self, "indices", "An array of indices to select features from " + - "a vector column. There can be no overlap with names.") - self.names = Param(self, "names", "An array of feature names to select features from " + - "a vector column. These names must be specified by ML " + - "org.apache.spark.ml.attribute.Attribute. There can be no overlap " + - "with indices.") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -1932,7 +1848,6 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has .. 
versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc vectorSize = Param(Params._dummy(), "vectorSize", "the dimension of codes after transforming from words") numPartitions = Param(Params._dummy(), "numPartitions", @@ -1950,13 +1865,6 @@ def __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, """ super(Word2Vec, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Word2Vec", self.uid) - self.vectorSize = Param(self, "vectorSize", - "the dimension of codes after transforming from words") - self.numPartitions = Param(self, "numPartitions", - "number of partitions for sentences of words") - self.minCount = Param(self, "minCount", - "the minimum number of times a token must appear to be included " + - "in the word2vec model's vocabulary") self._setDefault(vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, seed=None) kwargs = self.__init__._input_kwargs @@ -2075,7 +1983,6 @@ class PCA(JavaEstimator, HasInputCol, HasOutputCol): .. versionadded:: 1.5.0 """ - # a placeholder to make it appear in the generated doc k = Param(Params._dummy(), "k", "the number of principal components") @keyword_only @@ -2085,7 +1992,6 @@ def __init__(self, k=None, inputCol=None, outputCol=None): """ super(PCA, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.PCA", self.uid) - self.k = Param(self, "k", "the number of principal components") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -2185,7 +2091,6 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol): .. versionadded:: 1.5.0 """ - # a placeholder to make it appear in the generated doc formula = Param(Params._dummy(), "formula", "R model formula") @keyword_only @@ -2195,7 +2100,6 @@ def __init__(self, formula=None, featuresCol="features", labelCol="label"): """ super(RFormula, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid) - self.formula = Param(self, "formula", "R model formula") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 92ce96aa3c4d..3da36d32c5af 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -40,6 +40,15 @@ def __init__(self, parent, name, doc, expectedType=None): self.doc = str(doc) self.expectedType = expectedType + def _copy_new_parent(self, parent): + """Copy the current param to a new parent, must be a dummy param.""" + if self.parent == "undefined": + param = copy.copy(self) + param.parent = parent.uid + return param + else: + raise ValueError("Cannot copy from non-dummy parent %s." % parent) + def __str__(self): return str(self.parent) + "__" + self.name @@ -77,6 +86,19 @@ def __init__(self): #: value returned by :py:func:`params` self._params = None + # Copy the params from the class to the object + self._copy_params() + + def _copy_params(self): + """ + Copy all params defined on the class to current object. 
+ """ + cls = type(self) + src_name_attrs = [(x, getattr(cls, x)) for x in dir(cls)] + src_params = list(filter(lambda nameAttr: isinstance(nameAttr[1], Param), src_name_attrs)) + for name, param in src_params: + setattr(self, name, param._copy_new_parent(self)) + @property @since("1.3.0") def params(self): diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 82855bc4c75b..5e297b821482 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -50,13 +50,11 @@ def _gen_param_header(name, doc, defaultValueStr, expectedType): Mixin for param $name: $doc """ - # a placeholder to make it appear in the generated doc $name = Param(Params._dummy(), "$name", "$doc", $expectedType) def __init__(self): - super(Has$Name, self).__init__() - #: param for $doc - self.$name = Param(self, "$name", "$doc", $expectedType)''' + super(Has$Name, self).__init__()''' + if defaultValueStr is not None: template += ''' self._setDefault($name=$defaultValueStr)''' @@ -171,22 +169,17 @@ def get$Name(self): Mixin for Decision Tree parameters. """ - # a placeholder to make it appear in the generated doc $dummyPlaceHolders def __init__(self): - super(DecisionTreeParams, self).__init__() - $realParams''' + super(DecisionTreeParams, self).__init__()''' dtParamMethods = "" dummyPlaceholders = "" - realParams = "" paramTemplate = """$name = Param($owner, "$name", "$doc")""" for name, doc in decisionTreeParams: variable = paramTemplate.replace("$name", name).replace("$doc", doc) dummyPlaceholders += variable.replace("$owner", "Params._dummy()") + "\n " - realParams += "#: param for " + doc + "\n " - realParams += "self." + variable.replace("$owner", "self") + "\n " dtParamMethods += _gen_param_code(name, doc, None) + "\n" - code.append(decisionTreeCode.replace("$dummyPlaceHolders", dummyPlaceholders) - .replace("$realParams", realParams) + dtParamMethods) + code.append(decisionTreeCode.replace("$dummyPlaceHolders", dummyPlaceholders) + "\n" + + dtParamMethods) print("\n\n\n".join(code)) diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 23f94314844f..db4a8a54d495 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -25,13 +25,10 @@ class HasMaxIter(Params): Mixin for param maxIter: max number of iterations (>= 0). """ - # a placeholder to make it appear in the generated doc maxIter = Param(Params._dummy(), "maxIter", "max number of iterations (>= 0).", int) def __init__(self): super(HasMaxIter, self).__init__() - #: param for max number of iterations (>= 0). - self.maxIter = Param(self, "maxIter", "max number of iterations (>= 0).", int) def setMaxIter(self, value): """ @@ -52,13 +49,10 @@ class HasRegParam(Params): Mixin for param regParam: regularization parameter (>= 0). """ - # a placeholder to make it appear in the generated doc regParam = Param(Params._dummy(), "regParam", "regularization parameter (>= 0).", float) def __init__(self): super(HasRegParam, self).__init__() - #: param for regularization parameter (>= 0). - self.regParam = Param(self, "regParam", "regularization parameter (>= 0).", float) def setRegParam(self, value): """ @@ -79,13 +73,10 @@ class HasFeaturesCol(Params): Mixin for param featuresCol: features column name. 
""" - # a placeholder to make it appear in the generated doc featuresCol = Param(Params._dummy(), "featuresCol", "features column name.", str) def __init__(self): super(HasFeaturesCol, self).__init__() - #: param for features column name. - self.featuresCol = Param(self, "featuresCol", "features column name.", str) self._setDefault(featuresCol='features') def setFeaturesCol(self, value): @@ -107,13 +98,10 @@ class HasLabelCol(Params): Mixin for param labelCol: label column name. """ - # a placeholder to make it appear in the generated doc labelCol = Param(Params._dummy(), "labelCol", "label column name.", str) def __init__(self): super(HasLabelCol, self).__init__() - #: param for label column name. - self.labelCol = Param(self, "labelCol", "label column name.", str) self._setDefault(labelCol='label') def setLabelCol(self, value): @@ -135,13 +123,10 @@ class HasPredictionCol(Params): Mixin for param predictionCol: prediction column name. """ - # a placeholder to make it appear in the generated doc predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name.", str) def __init__(self): super(HasPredictionCol, self).__init__() - #: param for prediction column name. - self.predictionCol = Param(self, "predictionCol", "prediction column name.", str) self._setDefault(predictionCol='prediction') def setPredictionCol(self, value): @@ -163,13 +148,10 @@ class HasProbabilityCol(Params): Mixin for param probabilityCol: Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities. """ - # a placeholder to make it appear in the generated doc probabilityCol = Param(Params._dummy(), "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.", str) def __init__(self): super(HasProbabilityCol, self).__init__() - #: param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities. - self.probabilityCol = Param(self, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.", str) self._setDefault(probabilityCol='probability') def setProbabilityCol(self, value): @@ -191,13 +173,10 @@ class HasRawPredictionCol(Params): Mixin for param rawPredictionCol: raw prediction (a.k.a. confidence) column name. """ - # a placeholder to make it appear in the generated doc rawPredictionCol = Param(Params._dummy(), "rawPredictionCol", "raw prediction (a.k.a. confidence) column name.", str) def __init__(self): super(HasRawPredictionCol, self).__init__() - #: param for raw prediction (a.k.a. confidence) column name. - self.rawPredictionCol = Param(self, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name.", str) self._setDefault(rawPredictionCol='rawPrediction') def setRawPredictionCol(self, value): @@ -219,13 +198,10 @@ class HasInputCol(Params): Mixin for param inputCol: input column name. 
""" - # a placeholder to make it appear in the generated doc inputCol = Param(Params._dummy(), "inputCol", "input column name.", str) def __init__(self): super(HasInputCol, self).__init__() - #: param for input column name. - self.inputCol = Param(self, "inputCol", "input column name.", str) def setInputCol(self, value): """ @@ -246,13 +222,10 @@ class HasInputCols(Params): Mixin for param inputCols: input column names. """ - # a placeholder to make it appear in the generated doc inputCols = Param(Params._dummy(), "inputCols", "input column names.", None) def __init__(self): super(HasInputCols, self).__init__() - #: param for input column names. - self.inputCols = Param(self, "inputCols", "input column names.", None) def setInputCols(self, value): """ @@ -273,13 +246,10 @@ class HasOutputCol(Params): Mixin for param outputCol: output column name. """ - # a placeholder to make it appear in the generated doc outputCol = Param(Params._dummy(), "outputCol", "output column name.", str) def __init__(self): super(HasOutputCol, self).__init__() - #: param for output column name. - self.outputCol = Param(self, "outputCol", "output column name.", str) self._setDefault(outputCol=self.uid + '__output') def setOutputCol(self, value): @@ -301,13 +271,10 @@ class HasNumFeatures(Params): Mixin for param numFeatures: number of features. """ - # a placeholder to make it appear in the generated doc numFeatures = Param(Params._dummy(), "numFeatures", "number of features.", int) def __init__(self): super(HasNumFeatures, self).__init__() - #: param for number of features. - self.numFeatures = Param(self, "numFeatures", "number of features.", int) def setNumFeatures(self, value): """ @@ -328,13 +295,10 @@ class HasCheckpointInterval(Params): Mixin for param checkpointInterval: set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. """ - # a placeholder to make it appear in the generated doc checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.", int) def __init__(self): super(HasCheckpointInterval, self).__init__() - #: param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. - self.checkpointInterval = Param(self, "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.", int) def setCheckpointInterval(self, value): """ @@ -355,13 +319,10 @@ class HasSeed(Params): Mixin for param seed: random seed. """ - # a placeholder to make it appear in the generated doc seed = Param(Params._dummy(), "seed", "random seed.", int) def __init__(self): super(HasSeed, self).__init__() - #: param for random seed. - self.seed = Param(self, "seed", "random seed.", int) self._setDefault(seed=hash(type(self).__name__)) def setSeed(self, value): @@ -383,13 +344,10 @@ class HasTol(Params): Mixin for param tol: the convergence tolerance for iterative algorithms. """ - # a placeholder to make it appear in the generated doc tol = Param(Params._dummy(), "tol", "the convergence tolerance for iterative algorithms.", float) def __init__(self): super(HasTol, self).__init__() - #: param for the convergence tolerance for iterative algorithms. 
- self.tol = Param(self, "tol", "the convergence tolerance for iterative algorithms.", float) def setTol(self, value): """ @@ -410,13 +368,10 @@ class HasStepSize(Params): Mixin for param stepSize: Step size to be used for each iteration of optimization. """ - # a placeholder to make it appear in the generated doc stepSize = Param(Params._dummy(), "stepSize", "Step size to be used for each iteration of optimization.", float) def __init__(self): super(HasStepSize, self).__init__() - #: param for Step size to be used for each iteration of optimization. - self.stepSize = Param(self, "stepSize", "Step size to be used for each iteration of optimization.", float) def setStepSize(self, value): """ @@ -437,13 +392,10 @@ class HasHandleInvalid(Params): Mixin for param handleInvalid: how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later. """ - # a placeholder to make it appear in the generated doc handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", str) def __init__(self): super(HasHandleInvalid, self).__init__() - #: param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later. - self.handleInvalid = Param(self, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", str) def setHandleInvalid(self, value): """ @@ -464,13 +416,10 @@ class HasElasticNetParam(Params): Mixin for param elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. """ - # a placeholder to make it appear in the generated doc elasticNetParam = Param(Params._dummy(), "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", float) def __init__(self): super(HasElasticNetParam, self).__init__() - #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. - self.elasticNetParam = Param(self, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", float) self._setDefault(elasticNetParam=0.0) def setElasticNetParam(self, value): @@ -492,13 +441,10 @@ class HasFitIntercept(Params): Mixin for param fitIntercept: whether to fit an intercept term. """ - # a placeholder to make it appear in the generated doc fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.", bool) def __init__(self): super(HasFitIntercept, self).__init__() - #: param for whether to fit an intercept term. - self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.", bool) self._setDefault(fitIntercept=True) def setFitIntercept(self, value): @@ -520,13 +466,10 @@ class HasStandardization(Params): Mixin for param standardization: whether to standardize the training features before fitting the model. 
""" - # a placeholder to make it appear in the generated doc standardization = Param(Params._dummy(), "standardization", "whether to standardize the training features before fitting the model.", bool) def __init__(self): super(HasStandardization, self).__init__() - #: param for whether to standardize the training features before fitting the model. - self.standardization = Param(self, "standardization", "whether to standardize the training features before fitting the model.", bool) self._setDefault(standardization=True) def setStandardization(self, value): @@ -548,13 +491,10 @@ class HasThresholds(Params): Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold. """ - # a placeholder to make it appear in the generated doc thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", None) def __init__(self): super(HasThresholds, self).__init__() - #: param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold. - self.thresholds = Param(self, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", None) def setThresholds(self, value): """ @@ -575,13 +515,10 @@ class HasWeightCol(Params): Mixin for param weightCol: weight column name. If this is not set or empty, we treat all instance weights as 1.0. """ - # a placeholder to make it appear in the generated doc weightCol = Param(Params._dummy(), "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.", str) def __init__(self): super(HasWeightCol, self).__init__() - #: param for weight column name. If this is not set or empty, we treat all instance weights as 1.0. - self.weightCol = Param(self, "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.", str) def setWeightCol(self, value): """ @@ -602,13 +539,10 @@ class HasSolver(Params): Mixin for param solver: the solver algorithm for optimization. If this is not set or empty, default value is 'auto'. """ - # a placeholder to make it appear in the generated doc solver = Param(Params._dummy(), "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.", str) def __init__(self): super(HasSolver, self).__init__() - #: param for the solver algorithm for optimization. If this is not set or empty, default value is 'auto'. - self.solver = Param(self, "solver", "the solver algorithm for optimization. 
If this is not set or empty, default value is 'auto'.", str) self._setDefault(solver='auto') def setSolver(self, value): @@ -630,7 +564,6 @@ class DecisionTreeParams(Params): Mixin for Decision Tree parameters. """ - # a placeholder to make it appear in the generated doc maxDepth = Param(Params._dummy(), "maxDepth", "Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.") maxBins = Param(Params._dummy(), "maxBins", "Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.") minInstancesPerNode = Param(Params._dummy(), "minInstancesPerNode", "Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.") @@ -641,19 +574,7 @@ class DecisionTreeParams(Params): def __init__(self): super(DecisionTreeParams, self).__init__() - #: param for Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - self.maxDepth = Param(self, "maxDepth", "Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.") - #: param for Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature. - self.maxBins = Param(self, "maxBins", "Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.") - #: param for Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1. - self.minInstancesPerNode = Param(self, "minInstancesPerNode", "Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.") - #: param for Minimum information gain for a split to be considered at a tree node. - self.minInfoGain = Param(self, "minInfoGain", "Minimum information gain for a split to be considered at a tree node.") - #: param for Maximum memory in MB allocated to histogram aggregation. - self.maxMemoryInMB = Param(self, "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") - #: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval. - self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.") - + def setMaxDepth(self, value): """ Sets the value of :py:attr:`maxDepth`. diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 9f5f6ac8fa4e..661074ca9621 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -149,6 +149,8 @@ class Pipeline(Estimator): .. 
versionadded:: 1.3.0 """ + stages = Param(Params._dummy(), "stages", "pipeline stages") + @keyword_only def __init__(self, stages=None): """ @@ -157,8 +159,6 @@ def __init__(self, stages=None): if stages is None: stages = [] super(Pipeline, self).__init__() - #: Param for pipeline stages. - self.stages = Param(self, "stages", "pipeline stages") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index b44c66f73cc4..08180a2f25eb 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -85,7 +85,6 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc rank = Param(Params._dummy(), "rank", "rank of the factorization") numUserBlocks = Param(Params._dummy(), "numUserBlocks", "number of user blocks") numItemBlocks = Param(Params._dummy(), "numItemBlocks", "number of item blocks") @@ -108,16 +107,6 @@ def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemB """ super(ALS, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid) - self.rank = Param(self, "rank", "rank of the factorization") - self.numUserBlocks = Param(self, "numUserBlocks", "number of user blocks") - self.numItemBlocks = Param(self, "numItemBlocks", "number of item blocks") - self.implicitPrefs = Param(self, "implicitPrefs", "whether to use implicit preference") - self.alpha = Param(self, "alpha", "alpha for implicit preference") - self.userCol = Param(self, "userCol", "column name for user ids") - self.itemCol = Param(self, "itemCol", "column name for item ids") - self.ratingCol = Param(self, "ratingCol", "column name for ratings") - self.nonnegative = Param(self, "nonnegative", - "whether to use nonnegative constraint for least squares") self._setDefault(rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 401bac0223eb..74a2248ed07c 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -162,7 +162,6 @@ class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti DenseVector([0.0, 1.0]) """ - # a placeholder to make it appear in the generated doc isotonic = \ Param(Params._dummy(), "isotonic", "whether the output sequence should be isotonic/increasing (true) or" + @@ -181,14 +180,6 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred super(IsotonicRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.IsotonicRegression", self.uid) - self.isotonic = \ - Param(self, "isotonic", - "whether the output sequence should be isotonic/increasing (true) or" + - "antitonic/decreasing (false).") - self.featureIndex = \ - Param(self, "featureIndex", - "The index of the feature if featuresCol is a vector column, no effect " + - "otherwise.") self._setDefault(isotonic=True, featureIndex=0) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -262,15 +253,11 @@ class TreeEnsembleParams(DecisionTreeParams): Mixin for Decision Tree-based ensemble algorithms parameters. 
""" - # a placeholder to make it appear in the generated doc subsamplingRate = Param(Params._dummy(), "subsamplingRate", "Fraction of the training data " + "used for learning each decision tree, in range (0, 1].") def __init__(self): super(TreeEnsembleParams, self).__init__() - #: param for Fraction of the training data, in range (0, 1]. - self.subsamplingRate = Param(self, "subsamplingRate", "Fraction of the training data " + - "used for learning each decision tree, in range (0, 1].") @since("1.4.0") def setSubsamplingRate(self, value): @@ -294,7 +281,6 @@ class TreeRegressorParams(Params): """ supportedImpurities = ["variance"] - # a placeholder to make it appear in the generated doc impurity = Param(Params._dummy(), "impurity", "Criterion used for information gain calculation (case-insensitive). " + "Supported options: " + @@ -302,10 +288,6 @@ class TreeRegressorParams(Params): def __init__(self): super(TreeRegressorParams, self).__init__() - #: param for Criterion used for information gain calculation (case-insensitive). - self.impurity = Param(self, "impurity", "Criterion used for information " + - "gain calculation (case-insensitive). Supported options: " + - ", ".join(self.supportedImpurities)) @since("1.4.0") def setImpurity(self, value): @@ -329,7 +311,6 @@ class RandomForestParams(TreeEnsembleParams): """ supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"] - # a placeholder to make it appear in the generated doc numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1).") featureSubsetStrategy = \ Param(Params._dummy(), "featureSubsetStrategy", @@ -338,13 +319,6 @@ class RandomForestParams(TreeEnsembleParams): def __init__(self): super(RandomForestParams, self).__init__() - #: param for Number of trees to train (>= 1). - self.numTrees = Param(self, "numTrees", "Number of trees to train (>= 1).") - #: param for The number of features to consider for splits at each tree node. - self.featureSubsetStrategy = \ - Param(self, "featureSubsetStrategy", - "The number of features to consider for splits at each tree node. Supported " + - "options: " + ", ".join(self.supportedFeatureSubsetStrategies)) @since("1.4.0") def setNumTrees(self, value): @@ -609,7 +583,6 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc lossType = Param(Params._dummy(), "lossType", "Loss function which GBT tries to minimize (case-insensitive). " + "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) @@ -627,10 +600,6 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred """ super(GBTRegressor, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid) - #: param for Loss function which GBT tries to minimize (case-insensitive). - self.lossType = Param(self, "lossType", - "Loss function which GBT tries to minimize (case-insensitive). " + - "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1) @@ -713,7 +682,6 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi .. versionadded:: 1.6.0 """ - # a placeholder to make it appear in the generated doc censorCol = Param(Params._dummy(), "censorCol", "censor column name. 
The value of this column could be 0 or 1. " + "If the value is 1, it means the event has occurred i.e. " + @@ -739,20 +707,6 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred super(AFTSurvivalRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.AFTSurvivalRegression", self.uid) - #: Param for censor column name - self.censorCol = Param(self, "censorCol", - "censor column name. The value of this column could be 0 or 1. " + - "If the value is 1, it means the event has occurred i.e. " + - "uncensored; otherwise censored.") - #: Param for quantile probabilities array - self.quantileProbabilities = \ - Param(self, "quantileProbabilities", - "quantile probabilities array. Values of the quantile probabilities array " + - "should be in the range (0, 1) and the array should be non-empty.") - #: Param for quantiles column name - self.quantilesCol = Param(self, "quantilesCol", - "quantiles column name. This column will output quantiles of " + - "corresponding quantileProbabilities if it is set.") self._setDefault(censorCol="censor", quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]) kwargs = self.__init__._input_kwargs diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 9ea639dc4f96..c45a159c460f 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -185,6 +185,18 @@ def setParams(self, seed=None): class ParamTests(PySparkTestCase): + def test_copy_new_parent(self): + testParams = TestParams() + # Copying an instantiated param should fail + with self.assertRaises(ValueError): + testParams.maxIter._copy_new_parent(testParams) + # Copying a dummy param should succeed + TestParams.maxIter._copy_new_parent(testParams) + maxIter = testParams.maxIter + self.assertEqual(maxIter.name, "maxIter") + self.assertEqual(maxIter.doc, "max number of iterations (>= 0).") + self.assertTrue(maxIter.parent == testParams.uid) + def test_param(self): testParams = TestParams() maxIter = testParams.maxIter diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 08f8db57f440..0cbe97f1d839 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -115,18 +115,11 @@ class CrossValidator(Estimator, HasSeed): .. 
versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc estimator = Param(Params._dummy(), "estimator", "estimator to be cross-validated") - - # a placeholder to make it appear in the generated doc estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps") - - # a placeholder to make it appear in the generated doc evaluator = Param( Params._dummy(), "evaluator", "evaluator used to select hyper-parameters that maximize the cross-validated metric") - - # a placeholder to make it appear in the generated doc numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation") @keyword_only @@ -137,17 +130,6 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numF seed=None) """ super(CrossValidator, self).__init__() - #: param for estimator to be cross-validated - self.estimator = Param(self, "estimator", "estimator to be cross-validated") - #: param for estimator param maps - self.estimatorParamMaps = Param(self, "estimatorParamMaps", "estimator param maps") - #: param for the evaluator used to select hyper-parameters that - #: maximize the cross-validated metric - self.evaluator = Param( - self, "evaluator", - "evaluator used to select hyper-parameters that maximize the cross-validated metric") - #: param for number of folds for cross validation - self.numFolds = Param(self, "numFolds", "number of folds for cross validation") self._setDefault(numFolds=3) kwargs = self.__init__._input_kwargs self._set(**kwargs) From 22662b241629b56205719ede2f801a476e10a3cd Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 26 Jan 2016 17:24:40 -0800 Subject: [PATCH 028/131] [SPARK-12614][CORE] Don't throw non fatal exception from ask Right now RpcEndpointRef.ask may throw exception in some corner cases, such as calling ask after stopping RpcEnv. It's better to avoid throwing exception from RpcEndpointRef.ask. We can send the exception to the future for `ask`. Author: Shixiong Zhu Closes #10568 from zsxwing/send-ask-fail. 
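As a minimal sketch of the pattern this patch applies (not the actual NettyRpcEnv code; `send` is a hypothetical stand-in for the real dispatch to the local dispatcher or the outbox), any non-fatal failure raised while issuing the request is completed into the returned future instead of being thrown at the caller:

```
import scala.concurrent.{Future, Promise}
import scala.util.control.NonFatal

// Sketch only: route non-fatal failures from the send path into the Future
// returned by ask, rather than letting them propagate as a thrown exception.
// `send` stands in for the real dispatch logic (local dispatcher or outbox).
object AskSketch {
  def ask[T](send: Promise[T] => Unit): Future[T] = {
    val promise = Promise[T]()
    try {
      send(promise)
    } catch {
      case NonFatal(e) => promise.tryFailure(e)
    }
    promise.future
  }
}
```

The actual change below also moves the timeout scheduling inside the same try block and reports timeouts through the shared onFailure callback.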
--- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 54 ++++++++++--------- 1 file changed, 29 insertions(+), 25 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index ef876b1d8c15..9ae74d9d7b89 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -211,33 +211,37 @@ private[netty] class NettyRpcEnv( } } - if (remoteAddr == address) { - val p = Promise[Any]() - p.future.onComplete { - case Success(response) => onSuccess(response) - case Failure(e) => onFailure(e) - }(ThreadUtils.sameThread) - dispatcher.postLocalMessage(message, p) - } else { - val rpcMessage = RpcOutboxMessage(serialize(message), - onFailure, - (client, response) => onSuccess(deserialize[Any](client, response))) - postToOutbox(message.receiver, rpcMessage) - promise.future.onFailure { - case _: TimeoutException => rpcMessage.onTimeout() - case _ => + try { + if (remoteAddr == address) { + val p = Promise[Any]() + p.future.onComplete { + case Success(response) => onSuccess(response) + case Failure(e) => onFailure(e) + }(ThreadUtils.sameThread) + dispatcher.postLocalMessage(message, p) + } else { + val rpcMessage = RpcOutboxMessage(serialize(message), + onFailure, + (client, response) => onSuccess(deserialize[Any](client, response))) + postToOutbox(message.receiver, rpcMessage) + promise.future.onFailure { + case _: TimeoutException => rpcMessage.onTimeout() + case _ => + }(ThreadUtils.sameThread) + } + + val timeoutCancelable = timeoutScheduler.schedule(new Runnable { + override def run(): Unit = { + onFailure(new TimeoutException(s"Cannot receive any reply in ${timeout.duration}")) + } + }, timeout.duration.toNanos, TimeUnit.NANOSECONDS) + promise.future.onComplete { v => + timeoutCancelable.cancel(true) }(ThreadUtils.sameThread) + } catch { + case NonFatal(e) => + onFailure(e) } - - val timeoutCancelable = timeoutScheduler.schedule(new Runnable { - override def run(): Unit = { - promise.tryFailure( - new TimeoutException(s"Cannot receive any reply in ${timeout.duration}")) - } - }, timeout.duration.toNanos, TimeUnit.NANOSECONDS) - promise.future.onComplete { v => - timeoutCancelable.cancel(true) - }(ThreadUtils.sameThread) promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) } From 1dac964c1b996d38c65818414fc8401961a1de8a Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Tue, 26 Jan 2016 17:31:19 -0800 Subject: [PATCH 029/131] =?UTF-8?q?[SPARK-11622][MLLIB]=20Make=20LibSVMRel?= =?UTF-8?q?ation=20extends=20HadoopFsRelation=20and=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … Add LibSVMOutputWriter The behavior of LibSVMRelation is not changed except adding LibSVMOutputWriter * Partition is still not supported * Multiple input paths is not supported Author: Jeff Zhang Closes #9595 from zjffdu/SPARK-11622. 
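With the new output writer in place, the `libsvm` source works through the standard DataFrame reader/writer API in both directions. A minimal round-trip sketch (paths are placeholders), mirroring the test added in this patch:

```
import org.apache.spark.sql.{SQLContext, SaveMode}

// Round-trip sketch mirroring the new "write libsvm data and read it again" test.
// The input and output paths are placeholders.
object LibSVMRoundTripSketch {
  def roundTrip(sqlContext: SQLContext, inputPath: String, outputPath: String): Unit = {
    // Loads two columns: label (Double) and features (Vector).
    val df = sqlContext.read.format("libsvm").load(inputPath)
    // The write path is what this patch adds, backed by LibSVMOutputWriter.
    df.write.format("libsvm").mode(SaveMode.Overwrite).save(outputPath)
    // Reading the written data back yields the same label/features schema.
    val reloaded = sqlContext.read.format("libsvm").load(outputPath)
    reloaded.show()
  }
}
```

As noted above, partitioned output and multiple input paths are rejected (with an IOException) rather than silently mishandled.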
--- .../ml/source/libsvm/LibSVMRelation.scala | 102 +++++++++++++++--- .../source/libsvm/LibSVMRelationSuite.scala | 23 +++- project/MimaExcludes.scala | 4 + 3 files changed, 113 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 1bed542c4031..b9c364b05dc1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -17,16 +17,21 @@ package org.apache.spark.ml.source.libsvm +import java.io.IOException + import com.google.common.base.Objects +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.io.{NullWritable, Text} +import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat -import org.apache.spark.Logging import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SQLContext} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{DoubleType, StructField, StructType} +import org.apache.spark.sql.types._ /** * LibSVMRelation provides the DataFrame constructed from LibSVM format data. @@ -37,14 +42,10 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} */ private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String) (@transient val sqlContext: SQLContext) - extends BaseRelation with TableScan with Logging with Serializable { - - override def schema: StructType = StructType( - StructField("label", DoubleType, nullable = false) :: - StructField("features", new VectorUDT(), nullable = false) :: Nil - ) + extends HadoopFsRelation with Serializable { - override def buildScan(): RDD[Row] = { + override def buildScan(requiredColumns: Array[String], inputFiles: Array[FileStatus]) + : RDD[Row] = { val sc = sqlContext.sparkContext val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures) val sparse = vectorType == "sparse" @@ -66,8 +67,63 @@ private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val case _ => false } + + override def prepareJobForWrite(job: _root_.org.apache.hadoop.mapreduce.Job): + _root_.org.apache.spark.sql.sources.OutputWriterFactory = { + new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new LibSVMOutputWriter(path, dataSchema, context) + } + } + } + + override def paths: Array[String] = Array(path) + + override def dataSchema: StructType = StructType( + StructField("label", DoubleType, nullable = false) :: + StructField("features", new VectorUDT(), nullable = false) :: Nil) } + +private[libsvm] class LibSVMOutputWriter( + path: String, + dataSchema: StructType, + context: TaskAttemptContext) + extends OutputWriter { + + private[this] val buffer = new Text() + + private val recordWriter: RecordWriter[NullWritable, Text] = { + new TextOutputFormat[NullWritable, Text]() { + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + val configuration = context.getConfiguration + val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") + val taskAttemptId = context.getTaskAttemptID + val split = 
taskAttemptId.getTaskID.getId + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + } + }.getRecordWriter(context) + } + + override def write(row: Row): Unit = { + val label = row.get(0) + val vector = row.get(1).asInstanceOf[Vector] + val sb = new StringBuilder(label.toString) + vector.foreachActive { case (i, v) => + sb += ' ' + sb ++= s"${i + 1}:$v" + } + buffer.set(sb.mkString) + recordWriter.write(NullWritable.get(), buffer) + } + + override def close(): Unit = { + recordWriter.close(context) + } +} /** * `libsvm` package implements Spark SQL data source API for loading LIBSVM data as [[DataFrame]]. * The loaded [[DataFrame]] has two columns: `label` containing labels stored as doubles and @@ -99,16 +155,32 @@ private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val * @see [[https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ LIBSVM datasets]] */ @Since("1.6.0") -class DefaultSource extends RelationProvider with DataSourceRegister { +class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { @Since("1.6.0") override def shortName(): String = "libsvm" - @Since("1.6.0") - override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]) - : BaseRelation = { - val path = parameters.getOrElse("path", - throw new IllegalArgumentException("'path' must be specified")) + private def verifySchema(dataSchema: StructType): Unit = { + if (dataSchema.size != 2 || + (!dataSchema(0).dataType.sameType(DataTypes.DoubleType) + || !dataSchema(1).dataType.sameType(new VectorUDT()))) { + throw new IOException(s"Illegal schema for libsvm data, schema=${dataSchema}") + } + } + + override def createRelation( + sqlContext: SQLContext, + paths: Array[String], + dataSchema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): HadoopFsRelation = { + val path = if (paths.length == 1) paths(0) + else if (paths.isEmpty) throw new IOException("No input path specified for libsvm data") + else throw new IOException("Multiple input paths are not supported for libsvm data") + if (partitionColumns.isDefined && !partitionColumns.get.isEmpty) { + throw new IOException("Partition is not supported for libsvm data") + } + dataSchema.foreach(verifySchema(_)) val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt val vectorType = parameters.getOrElse("vectorType", "sparse") new LibSVMRelation(path, numFeatures, vectorType)(sqlContext) diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala index 5f4d5f11bdd6..528d9e21cb1f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.source.libsvm -import java.io.File +import java.io.{File, IOException} import com.google.common.base.Charsets import com.google.common.io.Files @@ -25,6 +25,7 @@ import com.google.common.io.Files import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.SaveMode import org.apache.spark.util.Utils class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -82,4 +83,24 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { val v = 
row1.getAs[SparseVector](1) assert(v == Vectors.sparse(100, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) } + + test("write libsvm data and read it again") { + val df = sqlContext.read.format("libsvm").load(path) + val tempDir2 = Utils.createTempDir() + val writepath = tempDir2.toURI.toString + df.write.format("libsvm").mode(SaveMode.Overwrite).save(writepath) + + val df2 = sqlContext.read.format("libsvm").load(writepath) + val row1 = df2.first() + val v = row1.getAs[SparseVector](1) + assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) + } + + test("write libsvm data failed due to invalid schema") { + val df = sqlContext.read.format("text").load(path) + val e = intercept[IOException] { + df.write.format("libsvm").save(path + "_2") + } + assert(e.getMessage.contains("Illegal schema for libsvm data")) + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 643bee69694d..fc7dc2181de8 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -203,6 +203,10 @@ object MimaExcludes { // SPARK-12847 Remove StreamingListenerBus and post all Streaming events to the same thread as Spark events ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus") + ) ++ Seq( + // SPARK-11622 Make LibSVMRelation extends HadoopFsRelation and Add LibSVMOutputWriter + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.source.libsvm.DefaultSource"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.source.libsvm.DefaultSource.createRelation") ) case v if v.startsWith("1.6") => Seq( From 555127387accdd7c1cf236912941822ba8af0a52 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Tue, 26 Jan 2016 17:34:01 -0800 Subject: [PATCH 030/131] [SPARK-12854][SQL] Implement complex types support in ColumnarBatch This patch adds support for complex types for ColumnarBatch. ColumnarBatch supports structs and arrays. There is a simple mapping between the richer catalyst types to these two. Strings are treated as an array of bytes. ColumnarBatch will contain a column for each node of the schema. Non-complex schemas consists of just leaf nodes. Structs represent an internal node with one child for each field. Arrays are internal nodes with one child. Structs just contain nullability. Arrays contain offsets and lengths into the child array. This structure is able to handle arbitrary nesting. It has the key property that we maintain columnar throughout and that primitive types are only stored in the leaf nodes and contiguous across rows. For example, if the schema is ``` array> ``` There are three columns in the schema. The internal nodes each have one children. The leaf node contains all the int data stored consecutively. As part of this, this patch adds append APIs in addition to the Put APIs (e.g. putLong(rowid, v) vs appendLong(v)). These APIs are necessary when the batch contains variable length elements. The vectors are not fixed length and will grow as necessary. This should make the usage a lot simpler for the writer. Author: Nong Li Closes #10820 from nongli/spark-12854. 
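To make the column decomposition concrete, here is a toy model (plain Scala with hypothetical names, not the ColumnVector API) of the layout described above for a hypothetical schema array<struct<f1: int, f2: double>>: the array node keeps offsets and lengths into its child, the struct node only tracks nullability, and each leaf stores its primitives contiguously across all rows.

```
// Toy model of the columnar layout, not the actual ColumnVector classes.
// Hypothetical schema: array<struct<f1: int, f2: double>>
object ColumnarLayoutSketch {
  case class IntLeaf(values: Array[Int])         // contiguous ints across all rows
  case class DoubleLeaf(values: Array[Double])   // contiguous doubles across all rows
  case class ArrayNode(offsets: Array[Int], lengths: Array[Int])  // one (offset, length) per row

  def main(args: Array[String]): Unit = {
    // Two rows: row 0 = [(1, 1.0), (2, 2.0)], row 1 = [(3, 3.0)]
    val f1  = IntLeaf(Array(1, 2, 3))
    val f2  = DoubleLeaf(Array(1.0, 2.0, 3.0))
    val arr = ArrayNode(offsets = Array(0, 2), lengths = Array(2, 1))

    // Reconstruct the array of structs for row 1 from the flat leaf columns.
    val row = 1
    val structsForRow = (0 until arr.lengths(row)).map { i =>
      val pos = arr.offsets(row) + i
      (f1.values(pos), f2.values(pos))
    }
    println(structsForRow)  // Vector((3,3.0))
  }
}
```

Because the child columns are not fixed length when elements vary in size, the append APIs added here (for example appendLong alongside putLong) let a writer grow the leaf vectors without computing offsets up front.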
--- .../sql/catalyst/expressions/UnsafeRow.java | 7 +- .../expressions/SpecificMutableRow.scala | 3 +- .../spark/sql/RandomDataGenerator.scala | 94 ++- .../spark/sql/RandomDataGeneratorSuite.scala | 4 +- .../GenerateUnsafeRowJoinerSuite.scala | 5 +- .../execution/vectorized/ColumnVector.java | 630 +++++++++++++++++- .../vectorized/ColumnVectorUtils.java | 126 ++++ .../execution/vectorized/ColumnarBatch.java | 70 +- .../vectorized/OffHeapColumnVector.java | 157 ++++- .../vectorized/OnHeapColumnVector.java | 169 ++++- .../UnsafeKVExternalSorterSuite.scala | 4 +- .../vectorized/ColumnarBatchBenchmark.scala | 78 ++- .../vectorized/ColumnarBatchSuite.scala | 397 ++++++++++- .../execution/AggregationQuerySuite.scala | 3 +- .../sql/sources/hadoopFsRelationSuites.scala | 3 +- .../org/apache/spark/unsafe/Platform.java | 11 + 16 files changed, 1671 insertions(+), 90 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 1a351933a366..a88bcbfdb7cc 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -68,6 +68,10 @@ public static int calculateBitSetWidthInBytes(int numFields) { return ((numFields + 63)/ 64) * 8; } + public static int calculateFixedPortionByteSize(int numFields) { + return 8 * numFields + calculateBitSetWidthInBytes(numFields); + } + /** * Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types) */ @@ -596,10 +600,9 @@ public byte[] getBytes() { public String toString() { StringBuilder build = new StringBuilder("["); for (int i = 0; i < sizeInBytes; i += 8) { + if (i != 0) build.append(','); build.append(java.lang.Long.toHexString(Platform.getLong(baseObject, baseOffset + i))); - build.append(','); } - build.deleteCharAt(build.length() - 1); build.append(']'); return build.toString(); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 475cbe005a6e..4615c55d676f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String /** * A parent class for mutable container objects that are reused when the values are changed, @@ -212,6 +211,8 @@ final class SpecificMutableRow(val values: Array[MutableValue]) def this() = this(Seq.empty) + def this(schema: StructType) = this(schema.fields.map(_.dataType)) + override def numFields: Int = values.length override def setNullAt(i: Int): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 7614f055e9c0..55efea80d1a4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -21,6 +21,7 @@ import 
java.lang.Double.longBitsToDouble import java.lang.Float.intBitsToFloat import java.math.MathContext +import scala.collection.mutable import scala.util.Random import org.apache.spark.sql.catalyst.CatalystTypeConverters @@ -74,13 +75,47 @@ object RandomDataGenerator { * @param numFields the number of fields in this schema * @param acceptedTypes types to draw from. */ - def randomSchema(numFields: Int, acceptedTypes: Seq[DataType]): StructType = { + def randomSchema(rand: Random, numFields: Int, acceptedTypes: Seq[DataType]): StructType = { StructType(Seq.tabulate(numFields) { i => - val dt = acceptedTypes(Random.nextInt(acceptedTypes.size)) - StructField("col_" + i, dt, nullable = true) + val dt = acceptedTypes(rand.nextInt(acceptedTypes.size)) + StructField("col_" + i, dt, nullable = rand.nextBoolean()) }) } + /** + * Returns a random nested schema. This will randomly generate structs and arrays drawn from + * acceptedTypes. + */ + def randomNestedSchema(rand: Random, totalFields: Int, acceptedTypes: Seq[DataType]): + StructType = { + val fields = mutable.ArrayBuffer.empty[StructField] + var i = 0 + var numFields = totalFields + while (numFields > 0) { + val v = rand.nextInt(3) + if (v == 0) { + // Simple type: + val dt = acceptedTypes(rand.nextInt(acceptedTypes.size)) + fields += new StructField("col_" + i, dt, rand.nextBoolean()) + numFields -= 1 + } else if (v == 1) { + // Array + val dt = acceptedTypes(rand.nextInt(acceptedTypes.size)) + fields += new StructField("col_" + i, ArrayType(dt), rand.nextBoolean()) + numFields -= 1 + } else { + // Struct + // TODO: do empty structs make sense? + val n = Math.max(rand.nextInt(numFields), 1) + val nested = randomNestedSchema(rand, n, acceptedTypes) + fields += new StructField("col_" + i, nested, rand.nextBoolean()) + numFields -= n + } + i += 1 + } + StructType(fields) + } + /** * Returns a function which generates random values for the given [[DataType]], or `None` if no * random data generator is defined for that data type. The generated values will use an external @@ -90,16 +125,13 @@ object RandomDataGenerator { * * @param dataType the type to generate values for * @param nullable whether null values should be generated - * @param seed an optional seed for the random number generator + * @param rand an optional random number generator * @return a function which can be called to generate random values. 
*/ def forType( dataType: DataType, nullable: Boolean = true, - seed: Option[Long] = None): Option[() => Any] = { - val rand = new Random() - seed.foreach(rand.setSeed) - + rand: Random = new Random): Option[() => Any] = { val valueGenerator: Option[() => Any] = dataType match { case StringType => Some(() => rand.nextString(rand.nextInt(MAX_STR_LEN))) case BinaryType => Some(() => { @@ -165,15 +197,15 @@ object RandomDataGenerator { rand, _.nextInt().toShort, Seq(Short.MinValue, Short.MaxValue, 0.toShort)) case NullType => Some(() => null) case ArrayType(elementType, containsNull) => { - forType(elementType, nullable = containsNull, seed = Some(rand.nextLong())).map { + forType(elementType, nullable = containsNull, rand).map { elementGenerator => () => Seq.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator()) } } case MapType(keyType, valueType, valueContainsNull) => { for ( - keyGenerator <- forType(keyType, nullable = false, seed = Some(rand.nextLong())); + keyGenerator <- forType(keyType, nullable = false, rand); valueGenerator <- - forType(valueType, nullable = valueContainsNull, seed = Some(rand.nextLong())) + forType(valueType, nullable = valueContainsNull, rand) ) yield { () => { Seq.fill(rand.nextInt(MAX_MAP_SIZE))((keyGenerator(), valueGenerator())).toMap @@ -182,7 +214,7 @@ object RandomDataGenerator { } case StructType(fields) => { val maybeFieldGenerators: Seq[Option[() => Any]] = fields.map { field => - forType(field.dataType, nullable = field.nullable, seed = Some(rand.nextLong())) + forType(field.dataType, nullable = field.nullable, rand) } if (maybeFieldGenerators.forall(_.isDefined)) { val fieldGenerators: Seq[() => Any] = maybeFieldGenerators.map(_.get) @@ -192,7 +224,7 @@ object RandomDataGenerator { } } case udt: UserDefinedType[_] => { - val maybeSqlTypeGenerator = forType(udt.sqlType, nullable, seed) + val maybeSqlTypeGenerator = forType(udt.sqlType, nullable, rand) // Because random data generator at here returns scala value, we need to // convert it to catalyst value to call udt's deserialize. val toCatalystType = CatalystTypeConverters.createToCatalystConverter(udt.sqlType) @@ -229,4 +261,40 @@ object RandomDataGenerator { } } } + + // Generates a random row for `schema`. 
+ def randomRow(rand: Random, schema: StructType): Row = { + val fields = mutable.ArrayBuffer.empty[Any] + schema.fields.foreach { f => + f.dataType match { + case ArrayType(childType, nullable) => { + val data = if (f.nullable && rand.nextFloat() <= PROBABILITY_OF_NULL) { + null + } else { + val arr = mutable.ArrayBuffer.empty[Any] + val n = 1// rand.nextInt(10) + var i = 0 + val generator = RandomDataGenerator.forType(childType, nullable, rand) + assert(generator.isDefined, "Unsupported type") + val gen = generator.get + while (i < n) { + arr += gen() + i += 1 + } + arr + } + fields += data + } + case StructType(children) => { + fields += randomRow(rand, StructType(children)) + } + case _ => + val generator = RandomDataGenerator.forType(f.dataType, f.nullable, rand) + assert(generator.isDefined, "Unsupported type") + val gen = generator.get + fields += gen() + } + } + Row.fromSeq(fields) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala index cccac7efa09e..b8ccdf7516d8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import scala.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.types._ @@ -32,7 +34,7 @@ class RandomDataGeneratorSuite extends SparkFunSuite { */ def testRandomDataGeneration(dataType: DataType, nullable: Boolean = true): Unit = { val toCatalyst = CatalystTypeConverters.createToCatalystConverter(dataType) - val generator = RandomDataGenerator.forType(dataType, nullable, Some(33)).getOrElse { + val generator = RandomDataGenerator.forType(dataType, nullable, new Random(33)).getOrElse { fail(s"Random data generator was not defined for $dataType") } if (nullable) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala index 59729e7646be..9f19745cefd2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala @@ -74,8 +74,9 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite { private def testConcatOnce(numFields1: Int, numFields2: Int, candidateTypes: Seq[DataType]) { info(s"schema size $numFields1, $numFields2") - val schema1 = RandomDataGenerator.randomSchema(numFields1, candidateTypes) - val schema2 = RandomDataGenerator.randomSchema(numFields2, candidateTypes) + val random = new Random() + val schema1 = RandomDataGenerator.randomSchema(random, numFields1, candidateTypes) + val schema2 = RandomDataGenerator.randomSchema(random, numFields2, candidateTypes) // Create the converters needed to convert from external row to internal row and to UnsafeRows. 
val internalConverter1 = CatalystTypeConverters.createToCatalystConverter(schema1) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 85509751dbbe..c119758d68b3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -17,22 +17,45 @@ package org.apache.spark.sql.execution.vectorized; import org.apache.spark.memory.MemoryMode; -import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +import org.apache.commons.lang.NotImplementedException; /** * This class represents a column of values and provides the main APIs to access the data * values. It supports all the types and contains get/put APIs as well as their batched versions. * The batched versions are preferable whenever possible. * - * Most of the APIs take the rowId as a parameter. This is the local 0-based row id for values + * To handle nested schemas, ColumnVector has two types: Arrays and Structs. In both cases these + * columns have child columns. All of the data is stored in the child columns and the parent column + * contains nullability, and in the case of Arrays, the lengths and offsets into the child column. + * Lengths and offsets are encoded identically to INTs. + * Maps are just a special case of a two field struct. + * Strings are handled as an Array of ByteType. + * + * Capacity: The data stored is dense but the arrays are not fixed capacity. It is the + * responsibility of the caller to call reserve() to ensure there is enough room before adding + * elements. This means that the put() APIs do not check as in common cases (i.e. flat schemas), + * the lengths are known up front. + * + * Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values * in the current RowBatch. * * A ColumnVector should be considered immutable once originally created. In other words, it is not * valid to call put APIs after reads until reset() is called. + * + * ColumnVectors are intended to be reused. */ public abstract class ColumnVector { /** - * Allocates a column with each element of size `width` either on or off heap. + * Allocates a column to store elements of `type` on or off heap. + * Capacity is the initial capacity of the vector and it will grow as necessary. Capacity is + * in number of elements, not number of bytes. */ public static ColumnVector allocate(int capacity, DataType type, MemoryMode mode) { if (mode == MemoryMode.OFF_HEAP) { @@ -42,13 +65,265 @@ public static ColumnVector allocate(int capacity, DataType type, MemoryMode mode } } + /** + * Holder object to return an array. This object is intended to be reused. Callers should + * copy the data out if it needs to be stored. + */ + public static final class Array extends ArrayData { + // The data for this array. This array contains elements from + // data[offset] to data[offset + length). + public final ColumnVector data; + public int length; + public int offset; + + // Populate if binary data is required for the Array. This is stored here as an optimization + // for string data. 
+ public byte[] byteArray; + public int byteArrayOffset; + + // Reused staging buffer, used for loading from offheap. + protected byte[] tmpByteArray = new byte[1]; + + protected Array(ColumnVector data) { + this.data = data; + } + + @Override + public final int numElements() { return length; } + + @Override + public ArrayData copy() { + throw new NotImplementedException(); + } + + // TODO: this is extremely expensive. + @Override + public Object[] array() { + DataType dt = data.dataType(); + Object[] list = new Object[length]; + + if (dt instanceof ByteType) { + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = data.getByte(offset + i); + } + } + } else if (dt instanceof IntegerType) { + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = data.getInt(offset + i); + } + } + } else if (dt instanceof DoubleType) { + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = data.getDouble(offset + i); + } + } + } else if (dt instanceof LongType) { + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = data.getLong(offset + i); + } + } + } else if (dt instanceof StringType) { + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = ColumnVectorUtils.toString(data.getByteArray(offset + i)); + } + } + } else { + throw new NotImplementedException("Type " + dt); + } + return list; + } + + @Override + public final boolean isNullAt(int ordinal) { return data.getIsNull(offset + ordinal); } + + @Override + public final boolean getBoolean(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public byte getByte(int ordinal) { return data.getByte(offset + ordinal); } + + @Override + public short getShort(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public int getInt(int ordinal) { return data.getInt(offset + ordinal); } + + @Override + public long getLong(int ordinal) { return data.getLong(offset + ordinal); } + + @Override + public float getFloat(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public double getDouble(int ordinal) { return data.getDouble(offset + ordinal); } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + throw new NotImplementedException(); + } + + @Override + public UTF8String getUTF8String(int ordinal) { + Array child = data.getByteArray(offset + ordinal); + return UTF8String.fromBytes(child.byteArray, child.byteArrayOffset, child.length); + } + + @Override + public byte[] getBinary(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public CalendarInterval getInterval(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public InternalRow getStruct(int ordinal, int numFields) { + throw new NotImplementedException(); + } + + @Override + public ArrayData getArray(int ordinal) { + return data.getArray(offset + ordinal); + } + + @Override + public MapData getMap(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public Object get(int ordinal, DataType dataType) { + throw new NotImplementedException(); + } + } + + /** + * Holder object to return a struct. This object is intended to be reused. + */ + public static final class Struct extends InternalRow { + // The fields that make up this struct. 
For example, if the struct had 2 int fields, the access + // to it would be: + // int f1 = fields[0].getInt[rowId] + // int f2 = fields[1].getInt[rowId] + public final ColumnVector[] fields; + + @Override + public boolean isNullAt(int fieldIdx) { return fields[fieldIdx].getIsNull(rowId); } + + @Override + public boolean getBoolean(int ordinal) { + throw new NotImplementedException(); + } + + public byte getByte(int fieldIdx) { return fields[fieldIdx].getByte(rowId); } + + @Override + public short getShort(int ordinal) { + throw new NotImplementedException(); + } + + public int getInt(int fieldIdx) { return fields[fieldIdx].getInt(rowId); } + public long getLong(int fieldIdx) { return fields[fieldIdx].getLong(rowId); } + + @Override + public float getFloat(int ordinal) { + throw new NotImplementedException(); + } + + public double getDouble(int fieldIdx) { return fields[fieldIdx].getDouble(rowId); } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + throw new NotImplementedException(); + } + + @Override + public UTF8String getUTF8String(int ordinal) { + Array a = getByteArray(ordinal); + return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); + } + + @Override + public byte[] getBinary(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public CalendarInterval getInterval(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public InternalRow getStruct(int ordinal, int numFields) { + return fields[ordinal].getStruct(rowId); + } + + public Array getArray(int fieldIdx) { return fields[fieldIdx].getArray(rowId); } + + @Override + public MapData getMap(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public Object get(int ordinal, DataType dataType) { + throw new NotImplementedException(); + } + + public Array getByteArray(int fieldIdx) { return fields[fieldIdx].getByteArray(rowId); } + public Struct getStruct(int fieldIdx) { return fields[fieldIdx].getStruct(rowId); } + + @Override + public final int numFields() { + return fields.length; + } + + @Override + public InternalRow copy() { + throw new NotImplementedException(); + } + + @Override + public boolean anyNull() { + throw new NotImplementedException(); + } + + protected int rowId; + + protected Struct(ColumnVector[] fields) { + this.fields = fields; + } + } + + /** + * Returns the data type of this column. + */ public final DataType dataType() { return type; } /** * Resets this column for writing. The currently stored values are no longer accessible. */ public void reset() { + if (childColumns != null) { + for (ColumnVector c: childColumns) { + c.reset(); + } + } numNulls = 0; + elementsAppended = 0; if (anyNullsSet) { putNotNulls(0, capacity); anyNullsSet = false; @@ -61,6 +336,12 @@ public void reset() { */ public abstract void close(); + /* + * Ensures that there is enough storage to store capcity elements. That is, the put() APIs + * must work for all rowIds < capcity. + */ + public abstract void reserve(int capacity); + /** * Returns the number of nulls in this column. */ @@ -96,6 +377,26 @@ public void reset() { */ public abstract boolean getIsNull(int rowId); + /** + * Sets the value at rowId to `value`. + */ + public abstract void putByte(int rowId, byte value); + + /** + * Sets values from [rowId, rowId + count) to value. 
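Because the `Array` and `Struct` holders above are reused, the object returned by `getArray()`/`getStruct()` is repositioned on the next call rather than reallocated; values that must outlive the call should be copied out first. A small illustration of the pitfall, assuming `col` is an already-populated integer-array column (hypothetical):

    val a = col.getArray(0)
    val firstElement = a.getInt(0)   // read row 0's data while the holder still points at it
    val b = col.getArray(1)          // the very same holder object, moved to row 1
    assert(a eq b)                   // `a` no longer describes row 0
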
+ */ + public abstract void putBytes(int rowId, int count, byte value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + */ + public abstract void putBytes(int rowId, int count, byte[] src, int srcIndex); + + /** + * Returns the value for rowId. + */ + public abstract byte getByte(int rowId); + /** * Sets the value at rowId to `value`. */ @@ -118,10 +419,36 @@ public void reset() { public abstract void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex); /** - * Returns the integer for rowId. + * Returns the value for rowId. */ public abstract int getInt(int rowId); + /** + * Sets the value at rowId to `value`. + */ + public abstract void putLong(int rowId, long value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putLongs(int rowId, int count, long value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + */ + public abstract void putLongs(int rowId, int count, long[] src, int srcIndex); + + /** + * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) + * The data in src must be 8-byte little endian longs. + */ + public abstract void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex); + + /** + * Returns the value for rowId. + */ + public abstract long getLong(int rowId); + /** * Sets the value at rowId to `value`. */ @@ -145,14 +472,248 @@ public void reset() { public abstract void putDoubles(int rowId, int count, byte[] src, int srcIndex); /** - * Returns the double for rowId. + * Returns the value for rowId. */ public abstract double getDouble(int rowId); + /** + * Puts a byte array that already exists in this column. + */ + public abstract void putArray(int rowId, int offset, int length); + + /** + * Returns the length of the array at rowid. + */ + public abstract int getArrayLength(int rowId); + + /** + * Returns the offset of the array at rowid. + */ + public abstract int getArrayOffset(int rowId); + + /** + * Returns a utility object to get structs. + */ + public Struct getStruct(int rowId) { + resultStruct.rowId = rowId; + return resultStruct; + } + + /** + * Returns the array at rowid. + */ + public final Array getArray(int rowId) { + resultArray.length = getArrayLength(rowId); + resultArray.offset = getArrayOffset(rowId); + return resultArray; + } + + /** + * Loads the data into array.byteArray. + */ + public abstract void loadBytes(Array array); + + /** + * Sets the value at rowId to `value`. + */ + public abstract int putByteArray(int rowId, byte[] value, int offset, int count); + public final int putByteArray(int rowId, byte[] value) { + return putByteArray(rowId, value, 0, value.length); + } + + /** + * Returns the value for rowId. + */ + public final Array getByteArray(int rowId) { + Array array = getArray(rowId); + array.data.loadBytes(array); + return array; + } + + /** + * Append APIs. These APIs all behave similarly and will append data to the current vector. It + * is not valid to mix the put and append APIs. The append APIs are slower and should only be + * used if the sizes are not known up front. + * In all these cases, the return value is the rowId for the first appended element. 
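A sketch of the append path: `reserve()` is invoked internally, so the caller does not need to size the vector up front, but put and append calls must not be mixed on one vector (the capacity and counts below are illustrative):

    import org.apache.spark.memory.MemoryMode
    import org.apache.spark.sql.execution.vectorized.ColumnVector
    import org.apache.spark.sql.types.IntegerType

    val col = ColumnVector.allocate(4, IntegerType, MemoryMode.ON_HEAP) // deliberately small
    var i = 0
    while (i < 1000) {
      if (i % 10 == 0) col.appendNull() else col.appendInt(i)           // grows as needed
      i += 1
    }
    assert(col.getElementsAppended == 1000)
    assert(col.getIsNull(0) && col.getInt(1) == 1)
    col.reset()   // clears nulls and the append cursor so the vector can be reused
    col.close()
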
+ */ + public final int appendNull() { + assert (!(dataType() instanceof StructType)); // Use appendStruct() + reserve(elementsAppended + 1); + putNull(elementsAppended); + return elementsAppended++; + } + + public final int appendNotNull() { + reserve(elementsAppended + 1); + putNotNull(elementsAppended); + return elementsAppended++; + } + + public final int appendNulls(int count) { + assert (!(dataType() instanceof StructType)); + reserve(elementsAppended + count); + int result = elementsAppended; + putNulls(elementsAppended, count); + elementsAppended += count; + return result; + } + + public final int appendNotNulls(int count) { + assert (!(dataType() instanceof StructType)); + reserve(elementsAppended + count); + int result = elementsAppended; + putNotNulls(elementsAppended, count); + elementsAppended += count; + return result; + } + + public final int appendByte(byte v) { + reserve(elementsAppended + 1); + putByte(elementsAppended, v); + return elementsAppended++; + } + + public final int appendBytes(int count, byte v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putBytes(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendBytes(int length, byte[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putBytes(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + + public final int appendInt(int v) { + reserve(elementsAppended + 1); + putInt(elementsAppended, v); + return elementsAppended++; + } + + public final int appendInts(int count, int v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putInts(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendInts(int length, int[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putInts(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + + public final int appendLong(long v) { + reserve(elementsAppended + 1); + putLong(elementsAppended, v); + return elementsAppended++; + } + + public final int appendLongs(int count, long v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putLongs(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendLongs(int length, long[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putLongs(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + + public final int appendDouble(double v) { + reserve(elementsAppended + 1); + putDouble(elementsAppended, v); + return elementsAppended++; + } + + public final int appendDoubles(int count, double v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putDoubles(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendDoubles(int length, double[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putDoubles(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + + public final int appendByteArray(byte[] value, int offset, int length) { + int copiedOffset = arrayData().appendBytes(length, value, offset); + reserve(elementsAppended + 1); + putArray(elementsAppended, copiedOffset, length); + return elementsAppended++; + } + + public final int appendArray(int 
length) { + reserve(elementsAppended + 1); + putArray(elementsAppended, arrayData().elementsAppended, length); + return elementsAppended++; + } + + /** + * Appends a NULL struct. This *has* to be used for structs instead of appendNull() as this + * recursively appends a NULL to its children. + * We don't have this logic as the general appendNull implementation to optimize the more + * common non-struct case. + */ + public final int appendStruct(boolean isNull) { + if (isNull) { + appendNull(); + for (ColumnVector c: childColumns) { + if (c.type instanceof StructType) { + c.appendStruct(true); + } else { + c.appendNull(); + } + } + } else { + appendNotNull(); + } + return elementsAppended; + } + + /** + * Returns the data for the underlying array. + */ + public final ColumnVector arrayData() { return childColumns[0]; } + + /** + * Returns the ordinal's child data column. + */ + public final ColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } + + /** + * Returns the elements appended. + */ + public int getElementsAppended() { return elementsAppended; } + /** * Maximum number of rows that can be stored in this column. */ - protected final int capacity; + protected int capacity; + + /** + * Data type for this column. + */ + protected final DataType type; /** * Number of nulls in this column. This is an optimization for the reader, to skip NULL checks. @@ -166,12 +727,63 @@ public void reset() { protected boolean anyNullsSet; /** - * Data type for this column. + * Default size of each array length value. This grows as necessary. */ - protected final DataType type; + protected static final int DEFAULT_ARRAY_LENGTH = 4; + + /** + * Current write cursor (row index) when appending data. + */ + protected int elementsAppended; - protected ColumnVector(int capacity, DataType type) { + /** + * If this is a nested type (array or struct), the column for the child data. + */ + protected final ColumnVector[] childColumns; + + /** + * Reusable Array holder for getArray(). + */ + protected final Array resultArray; + + /** + * Reusable Struct holder for getStruct(). + */ + protected final Struct resultStruct; + + /** + * Sets up the common state and also handles creating the child columns if this is a nested + * type. 
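Put differently, allocating a nested column transparently allocates its children: array, binary and string columns get a single element child (byte-backed children are sized at `capacity * DEFAULT_ARRAY_LENGTH`), while a struct column gets one child per field, reachable through `getChildColumn()`. A brief sketch:

    import org.apache.spark.memory.MemoryMode
    import org.apache.spark.sql.execution.vectorized.ColumnVector
    import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}

    val schema = new StructType().add("i", IntegerType).add("d", DoubleType)
    val col = ColumnVector.allocate(64, schema, MemoryMode.ON_HEAP)
    col.getChildColumn(0).putInt(0, 7)       // field 0 is backed by an IntegerType child
    col.getChildColumn(1).putDouble(0, 1.5)  // field 1 by a DoubleType child
    assert(col.getStruct(0).getInt(0) == 7)  // the Struct holder just redirects to the children
    col.close()
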
+ */ + protected ColumnVector(int capacity, DataType type, MemoryMode memMode) { this.capacity = capacity; this.type = type; + + if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType) { + DataType childType; + int childCapacity = capacity; + if (type instanceof ArrayType) { + childType = ((ArrayType)type).elementType(); + } else { + childType = DataTypes.ByteType; + childCapacity *= DEFAULT_ARRAY_LENGTH; + } + this.childColumns = new ColumnVector[1]; + this.childColumns[0] = ColumnVector.allocate(childCapacity, childType, memMode); + this.resultArray = new Array(this.childColumns[0]); + this.resultStruct = null; + } else if (type instanceof StructType) { + StructType st = (StructType)type; + this.childColumns = new ColumnVector[st.fields().length]; + for (int i = 0; i < childColumns.length; ++i) { + this.childColumns[i] = ColumnVector.allocate(capacity, st.fields()[i].dataType(), memMode); + } + this.resultArray = null; + this.resultStruct = new Struct(this.childColumns); + } else { + this.childColumns = null; + this.resultArray = null; + this.resultStruct = null; + } } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java new file mode 100644 index 000000000000..6c651a759d25 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.vectorized; + +import java.util.Iterator; +import java.util.List; + +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.*; + +import org.apache.commons.lang.NotImplementedException; + +/** + * Utilities to help manipulate data associate with ColumnVectors. These should be used mostly + * for debugging or other non-performance critical paths. + * These utilities are mostly used to convert ColumnVectors into other formats. + */ +public class ColumnVectorUtils { + public static String toString(ColumnVector.Array a) { + return new String(a.byteArray, a.byteArrayOffset, a.length); + } + + /** + * Returns the array data as the java primitive array. + * For example, an array of IntegerType will return an int[]. + * Throws exceptions for unhandled schemas. 
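For illustration, given an `ArrayType(IntegerType)` column populated as in the earlier array sketch (where row 1 holds [1, 2]), the helper unpacks the holder into a plain Java primitive array; it currently handles integer elements only and rejects NULLs:

    val ints = ColumnVectorUtils.toPrimitiveJavaArray(col.getArray(1)).asInstanceOf[Array[Int]]
    assert(ints.sameElements(Array(1, 2)))
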
+ */ + public static Object toPrimitiveJavaArray(ColumnVector.Array array) { + DataType dt = array.data.dataType(); + if (dt instanceof IntegerType) { + int[] result = new int[array.length]; + ColumnVector data = array.data; + for (int i = 0; i < result.length; i++) { + if (data.getIsNull(array.offset + i)) { + throw new RuntimeException("Cannot handle NULL values."); + } + result[i] = data.getInt(array.offset + i); + } + return result; + } else { + throw new NotImplementedException(); + } + } + + private static void appendValue(ColumnVector dst, DataType t, Object o) { + if (o == null) { + dst.appendNull(); + } else { + if (t == DataTypes.ByteType) { + dst.appendByte(((Byte)o).byteValue()); + } else if (t == DataTypes.IntegerType) { + dst.appendInt(((Integer)o).intValue()); + } else if (t == DataTypes.LongType) { + dst.appendLong(((Long)o).longValue()); + } else if (t == DataTypes.DoubleType) { + dst.appendDouble(((Double)o).doubleValue()); + } else if (t == DataTypes.StringType) { + byte[] b =((String)o).getBytes(); + dst.appendByteArray(b, 0, b.length); + } else { + throw new NotImplementedException("Type " + t); + } + } + } + + private static void appendValue(ColumnVector dst, DataType t, Row src, int fieldIdx) { + if (t instanceof ArrayType) { + ArrayType at = (ArrayType)t; + if (src.isNullAt(fieldIdx)) { + dst.appendNull(); + } else { + List values = src.getList(fieldIdx); + dst.appendArray(values.size()); + for (Object o : values) { + appendValue(dst.arrayData(), at.elementType(), o); + } + } + } else if (t instanceof StructType) { + StructType st = (StructType)t; + if (src.isNullAt(fieldIdx)) { + dst.appendStruct(true); + } else { + dst.appendStruct(false); + Row c = src.getStruct(fieldIdx); + for (int i = 0; i < st.fields().length; i++) { + appendValue(dst.getChildColumn(i), st.fields()[i].dataType(), c, i); + } + } + } else { + appendValue(dst, t, src.get(fieldIdx)); + } + } + + /** + * Converts an iterator of rows into a single ColumnBatch. 
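A minimal sketch of the conversion, matching the row-conversion tests later in this patch; the method takes a Java iterator, hence the `asJava` conversion:

    import scala.collection.JavaConverters._
    import org.apache.spark.memory.MemoryMode
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils
    import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

    val schema = new StructType().add("id", IntegerType).add("name", StringType)
    val rows = Seq(Row(1, "a"), Row(2, "bc"))
    val batch = ColumnVectorUtils.toBatch(schema, MemoryMode.ON_HEAP, rows.iterator.asJava)
    assert(batch.numRows() == 2 && batch.numCols() == 2)
    batch.close()
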
+ */ + public static ColumnarBatch toBatch( + StructType schema, MemoryMode memMode, Iterator row) { + ColumnarBatch batch = ColumnarBatch.allocate(schema, memMode); + int n = 0; + while (row.hasNext()) { + Row r = row.next(); + for (int i = 0; i < schema.fields().length; i++) { + appendValue(batch.column(i), schema.fields()[i].dataType(), r, i); + } + n++; + } + batch.setNumRows(n); + return batch; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 2c55f854c241..d558dae50c22 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -21,12 +21,10 @@ import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; -import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.Decimal; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -48,6 +46,7 @@ */ public final class ColumnarBatch { private static final int DEFAULT_BATCH_SIZE = 4 * 1024; + private static MemoryMode DEFAULT_MEMORY_MODE = MemoryMode.ON_HEAP; private final StructType schema; private final int capacity; @@ -64,6 +63,10 @@ public static ColumnarBatch allocate(StructType schema, MemoryMode memMode) { return new ColumnarBatch(schema, DEFAULT_BATCH_SIZE, memMode); } + public static ColumnarBatch allocate(StructType type) { + return new ColumnarBatch(type, DEFAULT_BATCH_SIZE, DEFAULT_MEMORY_MODE); + } + public static ColumnarBatch allocate(StructType schema, MemoryMode memMode, int maxRows) { return new ColumnarBatch(schema, maxRows, memMode); } @@ -82,25 +85,53 @@ public void close() { * Adapter class to interop with existing components that expect internal row. A lot of * performance is lost with this translation. */ - public final class Row extends InternalRow { + public static final class Row extends InternalRow { private int rowId; + private final ColumnarBatch parent; + private final int fixedLenRowSize; + + private Row(ColumnarBatch parent) { + this.parent = parent; + this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(parent.numCols()); + } /** * Marks this row as being filtered out. This means a subsequent iteration over the rows * in this batch will not include this row. */ public final void markFiltered() { - ColumnarBatch.this.markFiltered(rowId); + parent.markFiltered(rowId); } @Override public final int numFields() { - return ColumnarBatch.this.numCols(); + return parent.numCols(); } @Override + /** + * Revisit this. This is expensive. 
+ */ public final InternalRow copy() { - throw new NotImplementedException(); + UnsafeRow row = new UnsafeRow(parent.numCols()); + row.pointTo(new byte[fixedLenRowSize], fixedLenRowSize); + for (int i = 0; i < parent.numCols(); i++) { + if (isNullAt(i)) { + row.setNullAt(i); + } else { + DataType dt = parent.schema.fields()[i].dataType(); + if (dt instanceof IntegerType) { + row.setInt(i, getInt(i)); + } else if (dt instanceof LongType) { + row.setLong(i, getLong(i)); + } else if (dt instanceof DoubleType) { + row.setDouble(i, getDouble(i)); + } else { + throw new RuntimeException("Not implemented."); + } + } + } + return row; } @Override @@ -110,7 +141,7 @@ public final boolean anyNull() { @Override public final boolean isNullAt(int ordinal) { - return ColumnarBatch.this.column(ordinal).getIsNull(rowId); + return parent.column(ordinal).getIsNull(rowId); } @Override @@ -119,9 +150,7 @@ public final boolean getBoolean(int ordinal) { } @Override - public final byte getByte(int ordinal) { - throw new NotImplementedException(); - } + public final byte getByte(int ordinal) { return parent.column(ordinal).getByte(rowId); } @Override public final short getShort(int ordinal) { @@ -130,13 +159,11 @@ public final short getShort(int ordinal) { @Override public final int getInt(int ordinal) { - return ColumnarBatch.this.column(ordinal).getInt(rowId); + return parent.column(ordinal).getInt(rowId); } @Override - public final long getLong(int ordinal) { - throw new NotImplementedException(); - } + public final long getLong(int ordinal) { return parent.column(ordinal).getLong(rowId); } @Override public final float getFloat(int ordinal) { @@ -145,7 +172,7 @@ public final float getFloat(int ordinal) { @Override public final double getDouble(int ordinal) { - return ColumnarBatch.this.column(ordinal).getDouble(rowId); + return parent.column(ordinal).getDouble(rowId); } @Override @@ -155,7 +182,8 @@ public final Decimal getDecimal(int ordinal, int precision, int scale) { @Override public final UTF8String getUTF8String(int ordinal) { - throw new NotImplementedException(); + ColumnVector.Array a = parent.column(ordinal).getByteArray(rowId); + return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); } @Override @@ -170,12 +198,12 @@ public final CalendarInterval getInterval(int ordinal) { @Override public final InternalRow getStruct(int ordinal, int numFields) { - throw new NotImplementedException(); + return parent.column(ordinal).getStruct(rowId); } @Override public final ArrayData getArray(int ordinal) { - throw new NotImplementedException(); + return parent.column(ordinal).getArray(rowId); } @Override @@ -194,7 +222,7 @@ public final Object get(int ordinal, DataType dataType) { */ public Iterator rowIterator() { final int maxRows = ColumnarBatch.this.numRows(); - final Row row = new Row(); + final Row row = new Row(this); return new Iterator() { int rowId = 0; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 6180dd308e5e..335124fd5a60 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -18,12 +18,18 @@ import java.nio.ByteOrder; +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.types.ByteType; import org.apache.spark.sql.types.DataType; import 
org.apache.spark.sql.types.DoubleType; import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.types.UTF8String; +import org.apache.commons.lang.NotImplementedException; + import org.apache.commons.lang.NotImplementedException; /** @@ -35,21 +41,21 @@ public final class OffHeapColumnVector extends ColumnVector { private long nulls; private long data; + // Set iff the type is array. + private long lengthData; + private long offsetData; + protected OffHeapColumnVector(int capacity, DataType type) { - super(capacity, type); + super(capacity, type, MemoryMode.OFF_HEAP); if (!ByteOrder.nativeOrder().equals(ByteOrder.LITTLE_ENDIAN)) { throw new NotImplementedException("Only little endian is supported."); } + nulls = 0; + data = 0; + lengthData = 0; + offsetData = 0; - this.nulls = Platform.allocateMemory(capacity); - if (type instanceof IntegerType) { - this.data = Platform.allocateMemory(capacity * 4); - } else if (type instanceof DoubleType) { - this.data = Platform.allocateMemory(capacity * 8); - } else { - throw new RuntimeException("Unhandled " + type); - } - anyNullsSet = true; + reserveInternal(capacity); reset(); } @@ -67,8 +73,12 @@ public long nullsNativeAddress() { public final void close() { Platform.freeMemory(nulls); Platform.freeMemory(data); + Platform.freeMemory(lengthData); + Platform.freeMemory(offsetData); nulls = 0; data = 0; + lengthData = 0; + offsetData = 0; } // @@ -111,6 +121,33 @@ public final boolean getIsNull(int rowId) { return Platform.getByte(null, nulls + rowId) == 1; } + // + // APIs dealing with Bytes + // + + @Override + public final void putByte(int rowId, byte value) { + Platform.putByte(null, data + rowId, value); + + } + + @Override + public final void putBytes(int rowId, int count, byte value) { + for (int i = 0; i < count; ++i) { + Platform.putByte(null, data + rowId + i, value); + } + } + + @Override + public final void putBytes(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, null, data + rowId, count); + } + + @Override + public final byte getByte(int rowId) { + return Platform.getByte(null, data + rowId); + } + // // APIs dealing with ints // @@ -145,6 +182,40 @@ public final int getInt(int rowId) { return Platform.getInt(null, data + 4 * rowId); } + // + // APIs dealing with Longs + // + + @Override + public final void putLong(int rowId, long value) { + Platform.putLong(null, data + 8 * rowId, value); + } + + @Override + public final void putLongs(int rowId, int count, long value) { + long offset = data + 8 * rowId; + for (int i = 0; i < count; ++i, offset += 8) { + Platform.putLong(null, offset, value); + } + } + + @Override + public final void putLongs(int rowId, int count, long[] src, int srcIndex) { + Platform.copyMemory(src, Platform.LONG_ARRAY_OFFSET + srcIndex * 8, + null, data + 8 * rowId, count * 8); + } + + @Override + public final void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET, + null, data + 8 * rowId, count * 8); + } + + @Override + public final long getLong(int rowId) { + return Platform.getLong(null, data + 8 * rowId); + } + // // APIs dealing with doubles // @@ -178,4 +249,70 @@ public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) { public final double getDouble(int rowId) { return Platform.getDouble(null, data + rowId * 8); } + + // + // APIs 
dealing with Arrays. + // + @Override + public final void putArray(int rowId, int offset, int length) { + assert(offset >= 0 && offset + length <= childColumns[0].capacity); + Platform.putInt(null, lengthData + 4 * rowId, length); + Platform.putInt(null, offsetData + 4 * rowId, offset); + } + + @Override + public final int getArrayLength(int rowId) { + return Platform.getInt(null, lengthData + 4 * rowId); + } + + @Override + public final int getArrayOffset(int rowId) { + return Platform.getInt(null, offsetData + 4 * rowId); + } + + // APIs dealing with ByteArrays + @Override + public final int putByteArray(int rowId, byte[] value, int offset, int length) { + int result = arrayData().appendBytes(length, value, offset); + Platform.putInt(null, lengthData + 4 * rowId, length); + Platform.putInt(null, offsetData + 4 * rowId, result); + return result; + } + + @Override + public final void loadBytes(Array array) { + if (array.tmpByteArray.length < array.length) array.tmpByteArray = new byte[array.length]; + Platform.copyMemory( + null, data + array.offset, array.tmpByteArray, Platform.BYTE_ARRAY_OFFSET, array.length); + array.byteArray = array.tmpByteArray; + array.byteArrayOffset = 0; + } + + @Override + public final void reserve(int requiredCapacity) { + if (requiredCapacity > capacity) reserveInternal(requiredCapacity * 2); + } + + // Split out the slow path. + private final void reserveInternal(int newCapacity) { + if (this.resultArray != null) { + this.lengthData = + Platform.reallocateMemory(lengthData, elementsAppended * 4, newCapacity * 4); + this.offsetData = + Platform.reallocateMemory(offsetData, elementsAppended * 4, newCapacity * 4); + } else if (type instanceof ByteType) { + this.data = Platform.reallocateMemory(data, elementsAppended, newCapacity); + } else if (type instanceof IntegerType) { + this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4); + } else if (type instanceof LongType || type instanceof DoubleType) { + this.data = Platform.reallocateMemory(data, elementsAppended * 8, newCapacity * 8); + } else if (resultStruct != null) { + // Nothing to store. + } else { + throw new RuntimeException("Unhandled " + type); + } + this.nulls = Platform.reallocateMemory(nulls, elementsAppended, newCapacity); + Platform.setMemory(nulls + elementsAppended, (byte)0, newCapacity - elementsAppended); + capacity = newCapacity; + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 76d9956c3842..8197fa11cd4c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -16,13 +16,10 @@ */ package org.apache.spark.sql.execution.vectorized; -import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.DoubleType; -import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; -import java.nio.ByteBuffer; -import java.nio.DoubleBuffer; import java.util.Arrays; /** @@ -37,19 +34,18 @@ public final class OnHeapColumnVector extends ColumnVector { private byte[] nulls; // Array for each type. Only 1 is populated for any type. + private byte[] byteData; private int[] intData; + private long[] longData; private double[] doubleData; + // Only set if type is Array. 
+ private int[] arrayLengths; + private int[] arrayOffsets; + protected OnHeapColumnVector(int capacity, DataType type) { - super(capacity, type); - if (type instanceof IntegerType) { - this.intData = new int[capacity]; - } else if (type instanceof DoubleType) { - this.doubleData = new double[capacity]; - } else { - throw new RuntimeException("Unhandled " + type); - } - this.nulls = new byte[capacity]; + super(capacity, type, MemoryMode.ON_HEAP); + reserveInternal(capacity); reset(); } @@ -108,6 +104,32 @@ public final boolean getIsNull(int rowId) { return nulls[rowId] == 1; } + // + // APIs dealing with Bytes + // + + @Override + public final void putByte(int rowId, byte value) { + byteData[rowId] = value; + } + + @Override + public final void putBytes(int rowId, int count, byte value) { + for (int i = 0; i < count; ++i) { + byteData[i + rowId] = value; + } + } + + @Override + public final void putBytes(int rowId, int count, byte[] src, int srcIndex) { + System.arraycopy(src, srcIndex, byteData, rowId, count); + } + + @Override + public final byte getByte(int rowId) { + return byteData[rowId]; + } + // // APIs dealing with Ints // @@ -144,6 +166,43 @@ public final int getInt(int rowId) { return intData[rowId]; } + // + // APIs dealing with Longs + // + + @Override + public final void putLong(int rowId, long value) { + longData[rowId] = value; + } + + @Override + public final void putLongs(int rowId, int count, long value) { + for (int i = 0; i < count; ++i) { + longData[i + rowId] = value; + } + } + + @Override + public final void putLongs(int rowId, int count, long[] src, int srcIndex) { + System.arraycopy(src, srcIndex, longData, rowId, count); + } + + @Override + public final void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; + for (int i = 0; i < count; ++i) { + longData[i + rowId] = Platform.getLong(src, srcOffset); + srcIndex += 8; + srcOffset += 8; + } + } + + @Override + public final long getLong(int rowId) { + return longData[rowId]; + } + + // // APIs dealing with doubles // @@ -173,4 +232,86 @@ public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) { public final double getDouble(int rowId) { return doubleData[rowId]; } + + // + // APIs dealing with Arrays + // + + @Override + public final int getArrayLength(int rowId) { + return arrayLengths[rowId]; + } + @Override + public final int getArrayOffset(int rowId) { + return arrayOffsets[rowId]; + } + + @Override + public final void putArray(int rowId, int offset, int length) { + arrayOffsets[rowId] = offset; + arrayLengths[rowId] = length; + } + + @Override + public final void loadBytes(Array array) { + array.byteArray = byteData; + array.byteArrayOffset = array.offset; + } + + // + // APIs dealing with Byte Arrays + // + + @Override + public final int putByteArray(int rowId, byte[] value, int offset, int length) { + int result = arrayData().appendBytes(length, value, offset); + arrayOffsets[rowId] = result; + arrayLengths[rowId] = length; + return result; + } + + @Override + public final void reserve(int requiredCapacity) { + if (requiredCapacity > capacity) reserveInternal(requiredCapacity * 2); + } + + // Spilt this function out since it is the slow path. 
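Both the on-heap and off-heap vectors grow geometrically: `reserve(n)` is a no-op while `n` fits, and otherwise reallocates to `2 * n`, copying only the `elementsAppended` values already written, so repeated appends stay amortized constant time. The observable behaviour, sketched with illustrative sizes:

    import org.apache.spark.memory.MemoryMode
    import org.apache.spark.sql.execution.vectorized.ColumnVector
    import org.apache.spark.sql.types.LongType

    val col = ColumnVector.allocate(8, LongType, MemoryMode.OFF_HEAP)  // deliberately tiny
    (0 until 100000).foreach(i => col.appendLong(i.toLong))            // forces several reserveInternal() calls
    assert(col.getLong(99999) == 99999L)                               // earlier values survived each reallocation
    col.close()
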
+ private final void reserveInternal(int newCapacity) { + if (this.resultArray != null) { + int[] newLengths = new int[newCapacity]; + int[] newOffsets = new int[newCapacity]; + if (this.arrayLengths != null) { + System.arraycopy(this.arrayLengths, 0, newLengths, 0, elementsAppended); + System.arraycopy(this.arrayOffsets, 0, newOffsets, 0, elementsAppended); + } + arrayLengths = newLengths; + arrayOffsets = newOffsets; + } else if (type instanceof ByteType) { + byte[] newData = new byte[newCapacity]; + if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); + byteData = newData; + } else if (type instanceof IntegerType) { + int[] newData = new int[newCapacity]; + if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended); + intData = newData; + } else if (type instanceof LongType) { + long[] newData = new long[newCapacity]; + if (longData != null) System.arraycopy(longData, 0, newData, 0, elementsAppended); + longData = newData; + } else if (type instanceof DoubleType) { + double[] newData = new double[newCapacity]; + if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, elementsAppended); + doubleData = newData; + } else if (resultStruct != null) { + // Nothing to store. + } else { + throw new RuntimeException("Unhandled " + type); + } + + byte[] newNulls = new byte[newCapacity]; + if (nulls != null) System.arraycopy(nulls, 0, newNulls, 0, elementsAppended); + nulls = newNulls; + + capacity = newCapacity; + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 95c9550aebb0..8a95359d9de2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -40,8 +40,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { private val rand = new Random(42) for (i <- 0 until 6) { - val keySchema = RandomDataGenerator.randomSchema(rand.nextInt(10) + 1, keyTypes) - val valueSchema = RandomDataGenerator.randomSchema(rand.nextInt(10) + 1, valueTypes) + val keySchema = RandomDataGenerator.randomSchema(rand, rand.nextInt(10) + 1, keyTypes) + val valueSchema = RandomDataGenerator.randomSchema(rand, rand.nextInt(10) + 1, valueTypes) testKVSorter(keySchema, valueSchema, spill = i > 3) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index bfe944d835bd..8efdf8adb042 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -18,10 +18,12 @@ package org.apache.spark.sql.execution.datasources.parquet import java.nio.ByteBuffer +import scala.util.Random + import org.apache.spark.memory.MemoryMode import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.vectorized.ColumnVector -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{BinaryType, IntegerType} import org.apache.spark.unsafe.Platform import org.apache.spark.util.Benchmark import org.apache.spark.util.collection.BitSet @@ -239,6 +241,26 @@ object ColumnarBatchBenchmark { Platform.freeMemory(buffer) } + // Adding 
values by appending, instead of putting. + val onHeapAppend = { i: Int => + val col = ColumnVector.allocate(count, IntegerType, MemoryMode.ON_HEAP) + var sum = 0L + for (n <- 0L until iters) { + var i = 0 + while (i < count) { + col.appendInt(i) + i += 1 + } + i = 0 + while (i < count) { + sum += col.getInt(i) + i += 1 + } + col.reset() + } + col.close + } + /* Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz Int Read/Write: Avg Time(ms) Avg Rate(M/s) Relative Rate @@ -253,6 +275,7 @@ object ColumnarBatchBenchmark { Column(off heap direct) 237.6 1379.12 1.05 X UnsafeRow (on heap) 414.6 790.35 0.60 X UnsafeRow (off heap) 487.2 672.58 0.51 X + Column On Heap Append 530.1 618.14 0.59 X */ val benchmark = new Benchmark("Int Read/Write", count * iters) benchmark.addCase("Java Array")(javaArray) @@ -265,6 +288,7 @@ object ColumnarBatchBenchmark { benchmark.addCase("Column(off heap direct)")(columnOffheapDirect) benchmark.addCase("UnsafeRow (on heap)")(unsafeRowOnheap) benchmark.addCase("UnsafeRow (off heap)")(unsafeRowOffheap) + benchmark.addCase("Column On Heap Append")(onHeapAppend) benchmark.run() } @@ -314,8 +338,60 @@ object ColumnarBatchBenchmark { benchmark.run() } + def stringAccess(iters: Long): Unit = { + val chars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + val random = new Random(0) + + def randomString(min: Int, max: Int): String = { + val len = random.nextInt(max - min) + min + val sb = new StringBuilder(len) + var i = 0 + while (i < len) { + sb.append(chars.charAt(random.nextInt(chars.length()))); + i += 1 + } + return sb.toString + } + + val minString = 3 + val maxString = 32 + val count = 4 * 1000 + + val data = Seq.fill(count)(randomString(minString, maxString)).map(_.getBytes).toArray + + def column(memoryMode: MemoryMode) = { i: Int => + val column = ColumnVector.allocate(count, BinaryType, memoryMode) + var sum = 0L + for (n <- 0L until iters) { + var i = 0 + while (i < count) { + column.putByteArray(i, data(i)) + i += 1 + } + i = 0 + while (i < count) { + sum += column.getByteArray(i).length + i += 1 + } + column.reset() + } + } + + /* + String Read/Write: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------------- + On Heap 457.0 35.85 1.00 X + Off Heap 1206.0 13.59 0.38 X + */ + val benchmark = new Benchmark("String Read/Write", count * iters) + benchmark.addCase("On Heap")(column(MemoryMode.ON_HEAP)) + benchmark.addCase("Off Heap")(column(MemoryMode.OFF_HEAP)) + benchmark.run + } + def main(args: Array[String]): Unit = { intAccess(1024 * 40) booleanAccess(1024 * 40) + stringAccess(1024 * 4) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index d5e517c7f56b..215ca9ab6b77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -17,14 +17,15 @@ package org.apache.spark.sql.execution.vectorized +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.memory.MemoryMode -import org.apache.spark.sql.Row +import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} +import org.apache.spark.sql.types._ 
import org.apache.spark.unsafe.Platform class ColumnarBatchSuite extends SparkFunSuite { @@ -74,6 +75,45 @@ class ColumnarBatchSuite extends SparkFunSuite { }} } + test("Byte Apis") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val reference = mutable.ArrayBuffer.empty[Byte] + + val column = ColumnVector.allocate(1024, ByteType, memMode) + var idx = 0 + + val values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).map(_.toByte).toArray + column.putBytes(idx, 2, values, 0) + reference += 1 + reference += 2 + idx += 2 + + column.putBytes(idx, 3, values, 2) + reference += 3 + reference += 4 + reference += 5 + idx += 3 + + column.putByte(idx, 9) + reference += 9 + idx += 1 + + column.putBytes(idx, 3, 4) + reference += 4 + reference += 4 + reference += 4 + idx += 3 + + reference.zipWithIndex.foreach { v => + assert(v._1 == column.getByte(v._2), "MemoryMode" + memMode) + if (memMode == MemoryMode.OFF_HEAP) { + val addr = column.valuesNativeAddress() + assert(v._1 == Platform.getByte(null, addr + v._2)) + } + } + }} + } + test("Int Apis") { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { val seed = System.currentTimeMillis() @@ -142,6 +182,76 @@ class ColumnarBatchSuite extends SparkFunSuite { }} } + test("Long Apis") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val seed = System.currentTimeMillis() + val random = new Random(seed) + val reference = mutable.ArrayBuffer.empty[Long] + + val column = ColumnVector.allocate(1024, LongType, memMode) + var idx = 0 + + val values = (1L :: 2L :: 3L :: 4L :: 5L :: Nil).toArray + column.putLongs(idx, 2, values, 0) + reference += 1 + reference += 2 + idx += 2 + + column.putLongs(idx, 3, values, 2) + reference += 3 + reference += 4 + reference += 5 + idx += 3 + + val littleEndian = new Array[Byte](16) + littleEndian(0) = 7 + littleEndian(1) = 1 + littleEndian(8) = 6 + littleEndian(10) = 1 + + column.putLongsLittleEndian(idx, 1, littleEndian, 8) + column.putLongsLittleEndian(idx + 1, 1, littleEndian, 0) + reference += 6 + (1 << 16) + reference += 7 + (1 << 8) + idx += 2 + + column.putLongsLittleEndian(idx, 2, littleEndian, 0) + reference += 7 + (1 << 8) + reference += 6 + (1 << 16) + idx += 2 + + while (idx < column.capacity) { + val single = random.nextBoolean() + if (single) { + val v = random.nextLong() + column.putLong(idx, v) + reference += v + idx += 1 + } else { + + val n = math.min(random.nextInt(column.capacity / 20), column.capacity - idx) + column.putLongs(idx, n, n + 1) + var i = 0 + while (i < n) { + reference += (n + 1) + i += 1 + } + idx += n + } + } + + + reference.zipWithIndex.foreach { v => + assert(v._1 == column.getLong(v._2), "idx=" + v._2 + + " Seed = " + seed + " MemMode=" + memMode) + if (memMode == MemoryMode.OFF_HEAP) { + val addr = column.valuesNativeAddress() + assert(v._1 == Platform.getLong(null, addr + 8 * v._2)) + } + } + }} + } + test("Double APIs") { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { val seed = System.currentTimeMillis() @@ -209,15 +319,150 @@ class ColumnarBatchSuite extends SparkFunSuite { }} } + test("String APIs") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val reference = mutable.ArrayBuffer.empty[String] + + val column = ColumnVector.allocate(6, BinaryType, memMode) + assert(column.arrayData().elementsAppended == 0) + var idx = 0 + + val values = ("Hello" :: "abc" :: Nil).toArray + column.putByteArray(idx, values(0).getBytes, 0, values(0).getBytes().length) + reference 
+= values(0) + idx += 1 + assert(column.arrayData().elementsAppended == 5) + + column.putByteArray(idx, values(1).getBytes, 0, values(1).getBytes().length) + reference += values(1) + idx += 1 + assert(column.arrayData().elementsAppended == 8) + + // Just put llo + val offset = column.putByteArray(idx, values(0).getBytes, 2, values(0).getBytes().length - 2) + reference += "llo" + idx += 1 + assert(column.arrayData().elementsAppended == 11) + + // Put the same "ll" at offset. This should not allocate more memory in the column. + column.putArray(idx, offset, 2) + reference += "ll" + idx += 1 + assert(column.arrayData().elementsAppended == 11) + + // Put a long string + val s = "abcdefghijklmnopqrstuvwxyz" + column.putByteArray(idx, (s + s).getBytes) + reference += (s + s) + idx += 1 + assert(column.arrayData().elementsAppended == 11 + (s + s).length) + + reference.zipWithIndex.foreach { v => + assert(v._1.length == column.getArrayLength(v._2), "MemoryMode=" + memMode) + assert(v._1 == ColumnVectorUtils.toString(column.getByteArray(v._2)), + "MemoryMode" + memMode) + } + + column.reset() + assert(column.arrayData().elementsAppended == 0) + }} + } + + test("Int Array") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val column = ColumnVector.allocate(10, new ArrayType(IntegerType, true), memMode) + + // Fill the underlying data with all the arrays back to back. + val data = column.arrayData(); + var i = 0 + while (i < 6) { + data.putInt(i, i) + i += 1 + } + + // Populate it with arrays [0], [1, 2], [], [3, 4, 5] + column.putArray(0, 0, 1) + column.putArray(1, 1, 2) + column.putArray(2, 2, 0) + column.putArray(3, 3, 3) + + val a1 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]] + val a2 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(1)).asInstanceOf[Array[Int]] + val a3 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(2)).asInstanceOf[Array[Int]] + val a4 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(3)).asInstanceOf[Array[Int]] + assert(a1 === Array(0)) + assert(a2 === Array(1, 2)) + assert(a3 === Array.empty[Int]) + assert(a4 === Array(3, 4, 5)) + + // Verify the ArrayData APIs + assert(column.getArray(0).length == 1) + assert(column.getArray(0).getInt(0) == 0) + + assert(column.getArray(1).length == 2) + assert(column.getArray(1).getInt(0) == 1) + assert(column.getArray(1).getInt(1) == 2) + + assert(column.getArray(2).length == 0) + + assert(column.getArray(3).length == 3) + assert(column.getArray(3).getInt(0) == 3) + assert(column.getArray(3).getInt(1) == 4) + assert(column.getArray(3).getInt(2) == 5) + + // Add a longer array which requires resizing + column.reset + val array = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) + assert(data.capacity == 10) + data.reserve(array.length) + assert(data.capacity == array.length * 2) + data.putInts(0, array.length, array, 0) + column.putArray(0, 0, array.length) + assert(ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]] + === array) + }} + } + + test("Struct Column") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val schema = new StructType().add("int", IntegerType).add("double", DoubleType) + val column = ColumnVector.allocate(1024, schema, memMode) + + val c1 = column.getChildColumn(0) + val c2 = column.getChildColumn(1) + assert(c1.dataType() == IntegerType) + assert(c2.dataType() == DoubleType) + + c1.putInt(0, 123) + c2.putDouble(0, 3.45) + c1.putInt(1, 456) + c2.putDouble(1, 5.67) + 
+ val s = column.getStruct(0) + assert(s.fields(0).getInt(0) == 123) + assert(s.fields(0).getInt(1) == 456) + assert(s.fields(1).getDouble(0) == 3.45) + assert(s.fields(1).getDouble(1) == 5.67) + + assert(s.getInt(0) == 123) + assert(s.getDouble(1) == 3.45) + + val s2 = column.getStruct(1) + assert(s2.getInt(0) == 456) + assert(s2.getDouble(1) == 5.67) + }} + } + test("ColumnarBatch basic") { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { val schema = new StructType() .add("intCol", IntegerType) .add("doubleCol", DoubleType) .add("intCol2", IntegerType) + .add("string", BinaryType) val batch = ColumnarBatch.allocate(schema, memMode) - assert(batch.numCols() == 3) + assert(batch.numCols() == 4) assert(batch.numRows() == 0) assert(batch.numValidRows() == 0) assert(batch.capacity() > 0) @@ -227,10 +472,11 @@ class ColumnarBatchSuite extends SparkFunSuite { batch.column(0).putInt(0, 1) batch.column(1).putDouble(0, 1.1) batch.column(2).putNull(0) + batch.column(3).putByteArray(0, "Hello".getBytes) batch.setNumRows(1) // Verify the results of the row. - assert(batch.numCols() == 3) + assert(batch.numCols() == 4) assert(batch.numRows() == 1) assert(batch.numValidRows() == 1) assert(batch.rowIterator().hasNext == true) @@ -241,6 +487,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(batch.column(1).getDouble(0) == 1.1) assert(batch.column(1).getIsNull(0) == false) assert(batch.column(2).getIsNull(0) == true) + assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(0)) == "Hello") // Verify the iterator works correctly. val it = batch.rowIterator() @@ -251,6 +498,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(row.getDouble(1) == 1.1) assert(row.isNullAt(1) == false) assert(row.isNullAt(2) == true) + assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(0)) == "Hello") assert(it.hasNext == false) assert(it.hasNext == false) @@ -260,24 +508,27 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(batch.numValidRows() == 0) assert(batch.rowIterator().hasNext == false) - // Reset and add 3 throws + // Reset and add 3 rows batch.reset() assert(batch.numRows() == 0) assert(batch.numValidRows() == 0) assert(batch.rowIterator().hasNext == false) - // Add rows [NULL, 2.2, 2], [3, NULL, 3], [4, 4.4, 4] + // Add rows [NULL, 2.2, 2, "abc"], [3, NULL, 3, ""], [4, 4.4, 4, "world] batch.column(0).putNull(0) batch.column(1).putDouble(0, 2.2) batch.column(2).putInt(0, 2) + batch.column(3).putByteArray(0, "abc".getBytes) batch.column(0).putInt(1, 3) batch.column(1).putNull(1) batch.column(2).putInt(1, 3) + batch.column(3).putByteArray(1, "".getBytes) batch.column(0).putInt(2, 4) batch.column(1).putDouble(2, 4.4) batch.column(2).putInt(2, 4) + batch.column(3).putByteArray(2, "world".getBytes) batch.setNumRows(3) def rowEquals(x: InternalRow, y: Row): Unit = { @@ -289,30 +540,152 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(x.isNullAt(2) == y.isNullAt(2)) if (!x.isNullAt(2)) assert(x.getInt(2) == y.getInt(2)) + + assert(x.isNullAt(3) == y.isNullAt(3)) + if (!x.isNullAt(3)) assert(x.getString(3) == y.getString(3)) } + // Verify assert(batch.numRows() == 3) assert(batch.numValidRows() == 3) val it2 = batch.rowIterator() - rowEquals(it2.next(), Row(null, 2.2, 2)) - rowEquals(it2.next(), Row(3, null, 3)) - rowEquals(it2.next(), Row(4, 4.4, 4)) + rowEquals(it2.next(), Row(null, 2.2, 2, "abc")) + rowEquals(it2.next(), Row(3, null, 3, "")) + rowEquals(it2.next(), Row(4, 4.4, 4, "world")) assert(!it.hasNext) // Filter out some rows and 
verify batch.markFiltered(1) assert(batch.numValidRows() == 2) val it3 = batch.rowIterator() - rowEquals(it3.next(), Row(null, 2.2, 2)) - rowEquals(it3.next(), Row(4, 4.4, 4)) + rowEquals(it3.next(), Row(null, 2.2, 2, "abc")) + rowEquals(it3.next(), Row(4, 4.4, 4, "world")) assert(!it.hasNext) batch.markFiltered(2) assert(batch.numValidRows() == 1) val it4 = batch.rowIterator() - rowEquals(it4.next(), Row(null, 2.2, 2)) + rowEquals(it4.next(), Row(null, 2.2, 2, "abc")) batch.close }} } + + + private def doubleEquals(d1: Double, d2: Double): Boolean = { + if (d1.isNaN && d2.isNaN) { + true + } else { + d1 == d2 + } + } + + private def compareStruct(fields: Seq[StructField], r1: InternalRow, r2: Row, seed: Long) { + fields.zipWithIndex.foreach { v => { + assert(r1.isNullAt(v._2) == r2.isNullAt(v._2), "Seed = " + seed) + if (!r1.isNullAt(v._2)) { + v._1.dataType match { + case ByteType => assert(r1.getByte(v._2) == r2.getByte(v._2), "Seed = " + seed) + case IntegerType => assert(r1.getInt(v._2) == r2.getInt(v._2), "Seed = " + seed) + case LongType => assert(r1.getLong(v._2) == r2.getLong(v._2), "Seed = " + seed) + case DoubleType => assert(doubleEquals(r1.getDouble(v._2), r2.getDouble(v._2)), + "Seed = " + seed) + case StringType => + assert(r1.getString(v._2) == r2.getString(v._2), "Seed = " + seed) + case ArrayType(childType, n) => + val a1 = r1.getArray(v._2).array + val a2 = r2.getList(v._2).toArray + assert(a1.length == a2.length, "Seed = " + seed) + childType match { + case DoubleType => { + var i = 0 + while (i < a1.length) { + assert(doubleEquals(a1(i).asInstanceOf[Double], a2(i).asInstanceOf[Double]), + "Seed = " + seed) + i += 1 + } + } + case _ => assert(a1 === a2, "Seed = " + seed) + } + case StructType(childFields) => + compareStruct(childFields, r1.getStruct(v._2, fields.length), r2.getStruct(v._2), seed) + case _ => + throw new NotImplementedError("Not implemented " + v._1.dataType) + } + } + }} + } + + test("Convert rows") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val rows = Row(1, 2L, "a", 1.2, 'b'.toByte) :: Row(4, 5L, "cd", 2.3, 'a'.toByte) :: Nil + val schema = new StructType() + .add("i1", IntegerType) + .add("l2", LongType) + .add("string", StringType) + .add("d", DoubleType) + .add("b", ByteType) + + val batch = ColumnVectorUtils.toBatch(schema, memMode, rows.iterator.asJava) + assert(batch.numRows() == 2) + assert(batch.numCols() == 5) + + val it = batch.rowIterator() + val referenceIt = rows.iterator + while (it.hasNext) { + compareStruct(schema, it.next(), referenceIt.next(), 0) + } + batch.close() + } + }} + + /** + * This test generates a random schema data, serializes it to column batches and verifies the + * results. + */ + def testRandomRows(flatSchema: Boolean, numFields: Int) { + // TODO: add remaining types. Figure out why StringType doesn't work on jenkins. 
+ val types = Array(ByteType, IntegerType, LongType, DoubleType) + val seed = System.nanoTime() + val NUM_ROWS = 500 + val NUM_ITERS = 1000 + val random = new Random(seed) + var i = 0 + while (i < NUM_ITERS) { + val schema = if (flatSchema) { + RandomDataGenerator.randomSchema(random, numFields, types) + } else { + RandomDataGenerator.randomNestedSchema(random, numFields, types) + } + val rows = mutable.ArrayBuffer.empty[Row] + var j = 0 + while (j < NUM_ROWS) { + val row = RandomDataGenerator.randomRow(random, schema) + rows += row + j += 1 + } + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val batch = ColumnVectorUtils.toBatch(schema, memMode, rows.iterator.asJava) + assert(batch.numRows() == NUM_ROWS) + + val it = batch.rowIterator() + val referenceIt = rows.iterator + var k = 0 + while (it.hasNext) { + compareStruct(schema, it.next(), referenceIt.next(), seed) + k += 1 + } + batch.close() + }} + i += 1 + } + } + + test("Random flat schema") { + testRandomRows(true, 10) + } + + test("Random nested schema") { + testRandomRows(false, 30) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 76b36aa89182..3e4cf3f79e57 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.execution import scala.collection.JavaConverters._ +import scala.util.Random import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.UnsafeRow @@ -879,7 +880,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te RandomDataGenerator.forType( dataType = schemaForGenerator, nullable = true, - seed = Some(System.nanoTime())) + new Random(System.nanoTime())) val dataGenerator = maybeDataGenerator .getOrElse(fail(s"Failed to create data generator for schema $schemaForGenerator")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 3f9ecf6965e1..1a4b3ece72a6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources import scala.collection.JavaConverters._ +import scala.util.Random import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -122,7 +123,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes val dataGenerator = RandomDataGenerator.forType( dataType = dataType, nullable = true, - seed = Some(System.nanoTime()) + new Random(System.nanoTime()) ).getOrElse { fail(s"Failed to create data generator for schema $dataType") } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index 0d6b215fe5aa..b29bf6a464b3 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -105,6 +105,17 @@ public static void freeMemory(long address) { _UNSAFE.freeMemory(address); } + public static long reallocateMemory(long address, long oldSize, long newSize) { + long newMemory = 
_UNSAFE.allocateMemory(newSize); + copyMemory(null, address, null, newMemory, oldSize); + freeMemory(address); + return newMemory; + } + + public static void setMemory(long address, byte value, long size) { + _UNSAFE.setMemory(address, size, value); + } + public static void copyMemory( Object src, long srcOffset, Object dst, long dstOffset, long length) { // Check if dstOffset is before or after srcOffset to determine if we should copy From b72611f20a03c790b6fd341b6ffdb3b5437609ee Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 26 Jan 2016 17:59:05 -0800 Subject: [PATCH 031/131] [SPARK-7780][MLLIB] intercept in logisticregressionwith lbfgs should not be regularized The intercept in Logistic Regression represents a prior on categories which should not be regularized. In MLlib, the regularization is handled through Updater, and the Updater penalizes all the components without excluding the intercept which resulting poor training accuracy with regularization. The new implementation in ML framework handles this properly, and we should call the implementation in ML from MLlib since majority of users are still using MLlib api. Note that both of them are doing feature scalings to improve the convergence, and the only difference is ML version doesn't regularize the intercept. As a result, when lambda is zero, they will converge to the same solution. Previously partially reviewed at https://github.com/apache/spark/pull/6386#issuecomment-168781424 re-opening for dbtsai to review. Author: Holden Karau Author: Holden Karau Closes #10788 from holdenk/SPARK-7780-intercept-in-logisticregressionwithLBFGS-should-not-be-regularized. --- .../classification/LogisticRegression.scala | 36 ++++++-- .../classification/LogisticRegression.scala | 82 ++++++++++++++++++- .../spark/mllib/optimization/LBFGS.scala | 28 +++++++ .../GeneralizedLinearAlgorithm.scala | 34 ++++---- .../ml/classification/OneVsRestSuite.scala | 2 +- .../LogisticRegressionSuite.scala | 25 ++++-- 6 files changed, 179 insertions(+), 28 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index c98a78a515dc..9b2340a1f16f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -247,15 +247,27 @@ class LogisticRegression @Since("1.2.0") ( @Since("1.5.0") override def getThresholds: Array[Double] = super.getThresholds - override protected def train(dataset: DataFrame): LogisticRegressionModel = { - // Extract columns from data. If dataset is persisted, do not persist oldDataset. 
+ private var optInitialModel: Option[LogisticRegressionModel] = None + + /** @group setParam */ + private[spark] def setInitialModel(model: LogisticRegressionModel): this.type = { + this.optInitialModel = Some(model) + this + } + + override protected[spark] def train(dataset: DataFrame): LogisticRegressionModel = { + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + train(dataset, handlePersistence) + } + + protected[spark] def train(dataset: DataFrame, handlePersistence: Boolean): + LogisticRegressionModel = { val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map { case Row(label: Double, weight: Double, features: Vector) => Instance(label, weight, features) } - val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) val (summarizer, labelSummarizer) = { @@ -343,7 +355,21 @@ class LogisticRegression @Since("1.2.0") ( val initialCoefficientsWithIntercept = Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures) - if ($(fitIntercept)) { + if (optInitialModel.isDefined && optInitialModel.get.coefficients.size != numFeatures) { + val vec = optInitialModel.get.coefficients + logWarning( + s"Initial coefficients provided ${vec} did not match the expected size ${numFeatures}") + } + + if (optInitialModel.isDefined && optInitialModel.get.coefficients.size == numFeatures) { + val initialCoefficientsWithInterceptArray = initialCoefficientsWithIntercept.toArray + optInitialModel.get.coefficients.foreachActive { case (index, value) => + initialCoefficientsWithInterceptArray(index) = value + } + if ($(fitIntercept)) { + initialCoefficientsWithInterceptArray(numFeatures) == optInitialModel.get.intercept + } + } else if ($(fitIntercept)) { /* For binary logistic regression, when we initialize the coefficients as zeros, it will converge faster if we initialize the intercept such that @@ -434,7 +460,7 @@ object LogisticRegression extends DefaultParamsReadable[LogisticRegression] { */ @Since("1.4.0") @Experimental -class LogisticRegressionModel private[ml] ( +class LogisticRegressionModel private[spark] ( @Since("1.4.0") override val uid: String, @Since("1.6.0") val coefficients: Vector, @Since("1.3.0") val intercept: Double) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 2a7697b5a79c..bf68e3edd7ed 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -19,15 +19,18 @@ package org.apache.spark.mllib.classification import org.apache.spark.SparkContext import org.apache.spark.annotation.Since +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.classification.impl.GLMClassificationModel -import org.apache.spark.mllib.linalg.{DenseVector, Vector} +import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{DataValidators, Loader, Saveable} +import org.apache.spark.mllib.util.MLUtils.appendBias import org.apache.spark.rdd.RDD - +import 
org.apache.spark.sql.SQLContext +import org.apache.spark.storage.StorageLevel /** * Classification model trained using Multinomial/Binary Logistic Regression. @@ -332,6 +335,13 @@ object LogisticRegressionWithSGD { * Limited-memory BFGS. Standard feature scaling and L2 regularization are used by default. * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1} * for k classes multi-label classification problem. + * + * Earlier implementations of LogisticRegressionWithLBFGS applies a regularization + * penalty to all elements including the intercept. If this is called with one of + * standard updaters (L1Updater, or SquaredL2Updater) this is translated + * into a call to ml.LogisticRegression, otherwise this will use the existing mllib + * GeneralizedLinearAlgorithm trainer, resulting in a regularization penalty to the + * intercept. */ @Since("1.1.0") class LogisticRegressionWithLBFGS @@ -374,4 +384,72 @@ class LogisticRegressionWithLBFGS new LogisticRegressionModel(weights, intercept, numFeatures, numOfLinearPredictor + 1) } } + + /** + * Run Logistic Regression with the configured parameters on an input RDD + * of LabeledPoint entries. + * + * If a known updater is used calls the ml implementation, to avoid + * applying a regularization penalty to the intercept, otherwise + * defaults to the mllib implementation. If more than two classes + * or feature scaling is disabled, always uses mllib implementation. + * If using ml implementation, uses ml code to generate initial weights. + */ + override def run(input: RDD[LabeledPoint]): LogisticRegressionModel = { + run(input, generateInitialWeights(input), userSuppliedWeights = false) + } + + /** + * Run Logistic Regression with the configured parameters on an input RDD + * of LabeledPoint entries starting from the initial weights provided. + * + * If a known updater is used calls the ml implementation, to avoid + * applying a regularization penalty to the intercept, otherwise + * defaults to the mllib implementation. If more than two classes + * or feature scaling is disabled, always uses mllib implementation. + * Uses user provided weights. + */ + override def run(input: RDD[LabeledPoint], initialWeights: Vector): LogisticRegressionModel = { + run(input, initialWeights, userSuppliedWeights = true) + } + + private def run(input: RDD[LabeledPoint], initialWeights: Vector, userSuppliedWeights: Boolean): + LogisticRegressionModel = { + // ml's Logisitic regression only supports binary classifcation currently. 
+ if (numOfLinearPredictor == 1) { + def runWithMlLogisitcRegression(elasticNetParam: Double) = { + // Prepare the ml LogisticRegression based on our settings + val lr = new org.apache.spark.ml.classification.LogisticRegression() + lr.setRegParam(optimizer.getRegParam()) + lr.setElasticNetParam(elasticNetParam) + lr.setStandardization(useFeatureScaling) + if (userSuppliedWeights) { + val uid = Identifiable.randomUID("logreg-static") + lr.setInitialModel(new org.apache.spark.ml.classification.LogisticRegressionModel( + uid, initialWeights, 1.0)) + } + lr.setFitIntercept(addIntercept) + lr.setMaxIter(optimizer.getNumIterations()) + lr.setTol(optimizer.getConvergenceTol()) + // Convert our input into a DataFrame + val sqlContext = new SQLContext(input.context) + import sqlContext.implicits._ + val df = input.toDF() + // Determine if we should cache the DF + val handlePersistence = input.getStorageLevel == StorageLevel.NONE + // Train our model + val mlLogisticRegresionModel = lr.train(df, handlePersistence) + // convert the model + val weights = Vectors.dense(mlLogisticRegresionModel.coefficients.toArray) + createModel(weights, mlLogisticRegresionModel.intercept) + } + optimizer.getUpdater() match { + case x: SquaredL2Updater => runWithMlLogisitcRegression(1.0) + case x: L1Updater => runWithMlLogisitcRegression(0.0) + case _ => super.run(input, initialWeights) + } + } else { + super.run(input, initialWeights) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index efedc112d380..a5bd77e6bee9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -69,6 +69,13 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) this } + /* + * Get the convergence tolerance of iterations. + */ + private[mllib] def getConvergenceTol(): Double = { + this.convergenceTol + } + /** * Set the maximal number of iterations for L-BFGS. Default 100. * @deprecated use [[LBFGS#setNumIterations]] instead @@ -86,6 +93,13 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) this } + /** + * Get the maximum number of iterations for L-BFGS. Defaults to 100. + */ + private[mllib] def getNumIterations(): Int = { + this.maxNumIterations + } + /** * Set the regularization parameter. Default 0.0. */ @@ -94,6 +108,13 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) this } + /** + * Get the regularization parameter. + */ + private[mllib] def getRegParam(): Double = { + this.regParam + } + /** * Set the gradient function (of the loss function of one single data example) * to be used for L-BFGS. @@ -113,6 +134,13 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) this } + /** + * Returns the updater, limited to internal use. 
+ */ + private[mllib] def getUpdater(): Updater = { + updater + } + override def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = { val (weights, _) = LBFGS.runLBFGS( data, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index e60edc675c83..73da899a0edd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -140,7 +140,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] * translated back to resulting model weights, so it's transparent to users. * Note: This technique is used in both libsvm and glmnet packages. Default false. */ - private var useFeatureScaling = false + private[mllib] var useFeatureScaling = false /** * The dimension of training features. @@ -196,12 +196,9 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] } /** - * Run the algorithm with the configured parameters on an input - * RDD of LabeledPoint entries. - * + * Generate the initial weights when the user does not supply them */ - @Since("0.8.0") - def run(input: RDD[LabeledPoint]): M = { + protected def generateInitialWeights(input: RDD[LabeledPoint]): Vector = { if (numFeatures < 0) { numFeatures = input.map(_.features.size).first() } @@ -217,16 +214,23 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] * TODO: See if we can deprecate `intercept` in `GeneralizedLinearModel`, and always * have the intercept as part of weights to have consistent design. */ - val initialWeights = { - if (numOfLinearPredictor == 1) { - Vectors.zeros(numFeatures) - } else if (addIntercept) { - Vectors.zeros((numFeatures + 1) * numOfLinearPredictor) - } else { - Vectors.zeros(numFeatures * numOfLinearPredictor) - } + if (numOfLinearPredictor == 1) { + Vectors.zeros(numFeatures) + } else if (addIntercept) { + Vectors.zeros((numFeatures + 1) * numOfLinearPredictor) + } else { + Vectors.zeros(numFeatures * numOfLinearPredictor) } - run(input, initialWeights) + } + + /** + * Run the algorithm with the configured parameters on an input + * RDD of LabeledPoint entries. + * + */ + @Since("0.8.0") + def run(input: RDD[LabeledPoint]): M = { + run(input, generateInitialWeights(input)) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index d7983f92a348..445e50d867e1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -168,7 +168,7 @@ private class MockLogisticRegression(uid: String) extends LogisticRegression(uid setMaxIter(1) - override protected def train(dataset: DataFrame): LogisticRegressionModel = { + override protected[spark] def train(dataset: DataFrame): LogisticRegressionModel = { val labelSchema = dataset.schema($(labelCol)) // check for label attribute propagation. 
assert(MetadataUtils.getNumClasses(labelSchema).forall(_ == 2)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 8d14bb657215..8fef1316cd21 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ @@ -215,6 +216,11 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w // Test if we can correctly learn A, B where Y = logistic(A + B*X) test("logistic regression with LBFGS") { + val updaters: List[Updater] = List(new SquaredL2Updater(), new L1Updater()) + updaters.foreach(testLBFGS) + } + + private def testLBFGS(myUpdater: Updater): Unit = { val nPoints = 10000 val A = 2.0 val B = -1.5 @@ -223,7 +229,15 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w val testRDD = sc.parallelize(testData, 2) testRDD.cache() - val lr = new LogisticRegressionWithLBFGS().setIntercept(true) + + // Override the updater + class LogisticRegressionWithLBFGSCustomUpdater + extends LogisticRegressionWithLBFGS { + override val optimizer = + new LBFGS(new LogisticGradient, myUpdater) + } + + val lr = new LogisticRegressionWithLBFGSCustomUpdater().setIntercept(true) val model = lr.run(testRDD) @@ -396,10 +410,11 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w assert(modelA1.weights(0) ~== modelA3.weights(0) * 1.0E6 absTol 0.01) // Training data with different scales without feature standardization - // will not yield the same result in the scaled space due to poor - // convergence rate. - assert(modelB1.weights(0) !~== modelB2.weights(0) * 1.0E3 absTol 0.1) - assert(modelB1.weights(0) !~== modelB3.weights(0) * 1.0E6 absTol 0.1) + // should still converge quickly since the model still uses standardization but + // simply modifies the regularization function. See regParamL1Fun and related + // inside of LogisticRegression + assert(modelB1.weights(0) ~== modelB2.weights(0) * 1.0E3 absTol 0.1) + assert(modelB1.weights(0) ~== modelB3.weights(0) * 1.0E6 absTol 0.1) } test("multinomial logistic regression with LBFGS") { From e7f9199e709c46a6b5ad6b03c9ecf12cc19e3a41 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 26 Jan 2016 19:29:47 -0800 Subject: [PATCH 032/131] [SPARK-12903][SPARKR] Add covar_samp and covar_pop for SparkR Add ```covar_samp``` and ```covar_pop``` for SparkR. Should we also provide ```cov``` alias for ```covar_samp```? There is ```cov``` implementation at stats.R which masks ```stats::cov``` already, but may bring to breaking API change. cc sun-rui felixcheung shivaram Author: Yanbo Liang Closes #10829 from yanboliang/spark-12903. 
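Both new R wrappers delegate via `callJStatic` to the JVM-side aggregate functions of the same names in `org.apache.spark.sql.functions` (visible in the functions.R diff below). As a rough sketch of what they compute (sample covariance divides by n - 1, population covariance by n), the equivalent Scala usage in a shell with the usual `sqlContext` would look something like this; the toy data and the values in the comments are illustrative assumptions, not taken from the patch:

```scala
import org.apache.spark.sql.functions.{col, covar_pop, covar_samp}

// Toy DataFrame: c = 0..9 and d = 2 * c, so cov(c, d) = 2 * var(c).
val df = sqlContext.range(0, 10).selectExpr("id AS c", "id * 2 AS d")

df.agg(covar_samp(col("c"), col("d")), covar_pop(col("c"), col("d"))).show()
// covar_samp(c, d) is about 18.33 (divides by n - 1); covar_pop(c, d) is 16.5 (divides by n).
```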
--- R/pkg/NAMESPACE | 2 + R/pkg/R/functions.R | 58 +++++++++++++++++++++++ R/pkg/R/generics.R | 10 +++- R/pkg/R/stats.R | 3 +- R/pkg/inst/tests/testthat/test_sparkSQL.R | 2 + 5 files changed, 73 insertions(+), 2 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 2cc1544bef08..f194a46303e0 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -35,6 +35,8 @@ exportMethods("arrange", "count", "cov", "corr", + "covar_samp", + "covar_pop", "crosstab", "describe", "dim", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 9bb7876b384c..8f8651c295ee 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -275,6 +275,64 @@ setMethod("corr", signature(x = "Column"), column(jc) }) +#' cov +#' +#' Compute the sample covariance between two expressions. +#' +#' @rdname cov +#' @name cov +#' @family math_funcs +#' @export +#' @examples +#' \dontrun{ +#' cov(df$c, df$d) +#' cov("c", "d") +#' covar_samp(df$c, df$d) +#' covar_samp("c", "d") +#' } +setMethod("cov", signature(x = "characterOrColumn"), + function(x, col2) { + stopifnot(is(class(col2), "characterOrColumn")) + covar_samp(x, col2) + }) + +#' @rdname cov +#' @name covar_samp +setMethod("covar_samp", signature(col1 = "characterOrColumn", col2 = "characterOrColumn"), + function(col1, col2) { + stopifnot(class(col1) == class(col2)) + if (class(col1) == "Column") { + col1 <- col1@jc + col2 <- col2@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "covar_samp", col1, col2) + column(jc) + }) + +#' covar_pop +#' +#' Compute the population covariance between two expressions. +#' +#' @rdname covar_pop +#' @name covar_pop +#' @family math_funcs +#' @export +#' @examples +#' \dontrun{ +#' covar_pop(df$c, df$d) +#' covar_pop("c", "d") +#' } +setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOrColumn"), + function(col1, col2) { + stopifnot(class(col1) == class(col2)) + if (class(col1) == "Column") { + col1 <- col1@jc + col2 <- col2@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "covar_pop", col1, col2) + column(jc) + }) + #' cos #' #' Computes the cosine of the given value. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 04784d51566c..2dba71abec68 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -418,12 +418,20 @@ setGeneric("columns", function(x) {standardGeneric("columns") }) #' @rdname statfunctions #' @export -setGeneric("cov", function(x, col1, col2) {standardGeneric("cov") }) +setGeneric("cov", function(x, ...) {standardGeneric("cov") }) #' @rdname statfunctions #' @export setGeneric("corr", function(x, ...) {standardGeneric("corr") }) +#' @rdname statfunctions +#' @export +setGeneric("covar_samp", function(col1, col2) {standardGeneric("covar_samp") }) + +#' @rdname statfunctions +#' @export +setGeneric("covar_pop", function(col1, col2) {standardGeneric("covar_pop") }) + #' @rdname summary #' @export setGeneric("describe", function(x, col, ...) 
{ standardGeneric("describe") }) diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index d17cce9c756e..2e8076843f08 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -66,8 +66,9 @@ setMethod("crosstab", #' cov <- cov(df, "title", "gender") #' } setMethod("cov", - signature(x = "DataFrame", col1 = "character", col2 = "character"), + signature(x = "DataFrame"), function(x, col1, col2) { + stopifnot(class(col1) == "character" && class(col2) == "character") statFunctions <- callJMethod(x@sdf, "stat") callJMethod(statFunctions, "cov", col1, col2) }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index b52a11fb1a34..7b5713720df8 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -996,6 +996,8 @@ test_that("column functions", { c14 <- cume_dist() + ntile(1) + corr(c, c1) c15 <- dense_rank() + percent_rank() + rank() + row_number() c16 <- is.nan(c) + isnan(c) + isNaN(c) + c17 <- cov(c, c1) + cov("c", "c1") + covar_samp(c, c1) + covar_samp("c", "c1") + c18 <- covar_pop(c, c1) + covar_pop("c", "c1") # Test if base::is.nan() is exposed expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE)) From ce38a35b764397fcf561ac81de6da96579f5c13e Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 26 Jan 2016 20:12:34 -0800 Subject: [PATCH 033/131] [SPARK-12935][SQL] DataFrame API for Count-Min Sketch This PR integrates Count-Min Sketch from spark-sketch into DataFrame. This version resorts to `RDD.aggregate` for building the sketch. A more performant UDAF version can be built in future follow-up PRs. Author: Cheng Lian Closes #10911 from liancheng/cms-df-api. --- .../apache/spark/util/sketch/BloomFilter.java | 10 ++- .../spark/util/sketch/CountMinSketch.java | 26 +++--- .../spark/util/sketch/CountMinSketchImpl.java | 56 ++++++++----- sql/core/pom.xml | 5 ++ .../spark/sql/DataFrameStatFunctions.scala | 81 +++++++++++++++++++ .../apache/spark/sql/JavaDataFrameSuite.java | 28 ++++++- .../apache/spark/sql/DataFrameStatSuite.scala | 36 +++++++++ 7 files changed, 205 insertions(+), 37 deletions(-) diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index 00378d58518f..d392fb187ad6 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -47,10 +47,12 @@ public abstract class BloomFilter { public enum Version { /** * {@code BloomFilter} binary format version 1 (all values written in big-endian order): - * - Version number, always 1 (32 bit) - * - Total number of words of the underlying bit array (32 bit) - * - The words/longs (numWords * 64 bit) - * - Number of hash functions (32 bit) + *
+     * <ul>
+     *   <li>Version number, always 1 (32 bit)</li>
+     *   <li>Total number of words of the underlying bit array (32 bit)</li>
+     *   <li>The words/longs (numWords * 64 bit)</li>
+     *   <li>Number of hash functions (32 bit)</li>
+     * </ul>
*/ V1(1); diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index 00c0b1b9e2db..5692e574d4c7 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -59,16 +59,22 @@ abstract public class CountMinSketch { public enum Version { /** * {@code CountMinSketch} binary format version 1 (all values written in big-endian order): - * - Version number, always 1 (32 bit) - * - Total count of added items (64 bit) - * - Depth (32 bit) - * - Width (32 bit) - * - Hash functions (depth * 64 bit) - * - Count table - * - Row 0 (width * 64 bit) - * - Row 1 (width * 64 bit) - * - ... - * - Row depth - 1 (width * 64 bit) + *
+     * <ul>
+     *   <li>Version number, always 1 (32 bit)</li>
+     *   <li>Total count of added items (64 bit)</li>
+     *   <li>Depth (32 bit)</li>
+     *   <li>Width (32 bit)</li>
+     *   <li>Hash functions (depth * 64 bit)</li>
+     *   <li>
+     *     Count table
+     *     <ul>
+     *       <li>Row 0 (width * 64 bit)</li>
+     *       <li>Row 1 (width * 64 bit)</li>
+     *       <li>...</li>
+     *       <li>Row {@code depth - 1} (width * 64 bit)</li>
+     *     </ul>
+     *   </li>
+     * </ul>
*/ V1(1); diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index d08809605a93..8cc29e407630 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -21,13 +21,16 @@ import java.io.DataOutputStream; import java.io.IOException; import java.io.InputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; import java.io.OutputStream; +import java.io.Serializable; import java.io.UnsupportedEncodingException; import java.util.Arrays; import java.util.Random; -class CountMinSketchImpl extends CountMinSketch { - public static final long PRIME_MODULUS = (1L << 31) - 1; +class CountMinSketchImpl extends CountMinSketch implements Serializable { + private static final long PRIME_MODULUS = (1L << 31) - 1; private int depth; private int width; @@ -37,6 +40,9 @@ class CountMinSketchImpl extends CountMinSketch { private double eps; private double confidence; + private CountMinSketchImpl() { + } + CountMinSketchImpl(int depth, int width, int seed) { this.depth = depth; this.width = width; @@ -55,16 +61,6 @@ class CountMinSketchImpl extends CountMinSketch { initTablesWith(depth, width, seed); } - CountMinSketchImpl(int depth, int width, long totalCount, long hashA[], long table[][]) { - this.depth = depth; - this.width = width; - this.eps = 2.0 / width; - this.confidence = 1 - 1 / Math.pow(2, depth); - this.hashA = hashA; - this.table = table; - this.totalCount = totalCount; - } - @Override public boolean equals(Object other) { if (other == this) { @@ -325,27 +321,43 @@ public void writeTo(OutputStream out) throws IOException { } public static CountMinSketchImpl readFrom(InputStream in) throws IOException { + CountMinSketchImpl sketch = new CountMinSketchImpl(); + sketch.readFrom0(in); + return sketch; + } + + private void readFrom0(InputStream in) throws IOException { DataInputStream dis = new DataInputStream(in); - // Ignores version number - dis.readInt(); + int version = dis.readInt(); + if (version != Version.V1.getVersionNumber()) { + throw new IOException("Unexpected Count-Min Sketch version number (" + version + ")"); + } - long totalCount = dis.readLong(); - int depth = dis.readInt(); - int width = dis.readInt(); + this.totalCount = dis.readLong(); + this.depth = dis.readInt(); + this.width = dis.readInt(); + this.eps = 2.0 / width; + this.confidence = 1 - 1 / Math.pow(2, depth); - long hashA[] = new long[depth]; + this.hashA = new long[depth]; for (int i = 0; i < depth; ++i) { - hashA[i] = dis.readLong(); + this.hashA[i] = dis.readLong(); } - long table[][] = new long[depth][width]; + this.table = new long[depth][width]; for (int i = 0; i < depth; ++i) { for (int j = 0; j < width; ++j) { - table[i][j] = dis.readLong(); + this.table[i][j] = dis.readLong(); } } + } + + private void writeObject(ObjectOutputStream out) throws IOException { + this.writeTo(out); + } - return new CountMinSketchImpl(depth, width, totalCount, hashA, table); + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + this.readFrom0(in); } } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 31b364f351d5..4bb55f6b7f73 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -42,6 +42,11 @@ 1.5.6 jar + + org.apache.spark + spark-sketch_2.10 + ${project.version} + org.apache.spark spark-core_${scala.binary.version} 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index e66aa5f94718..465b12bb59d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -23,6 +23,8 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.execution.stat._ +import org.apache.spark.sql.types._ +import org.apache.spark.util.sketch.CountMinSketch /** * :: Experimental :: @@ -309,4 +311,83 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) } + + /** + * Builds a Count-min Sketch over a specified column. + * + * @param colName name of the column over which the sketch is built + * @param depth depth of the sketch + * @param width width of the sketch + * @param seed random seed + * @return a [[CountMinSketch]] over column `colName` + * @since 2.0.0 + */ + def countMinSketch(colName: String, depth: Int, width: Int, seed: Int): CountMinSketch = { + countMinSketch(Column(colName), depth, width, seed) + } + + /** + * Builds a Count-min Sketch over a specified column. + * + * @param colName name of the column over which the sketch is built + * @param eps relative error of the sketch + * @param confidence confidence of the sketch + * @param seed random seed + * @return a [[CountMinSketch]] over column `colName` + * @since 2.0.0 + */ + def countMinSketch( + colName: String, eps: Double, confidence: Double, seed: Int): CountMinSketch = { + countMinSketch(Column(colName), eps, confidence, seed) + } + + /** + * Builds a Count-min Sketch over a specified column. + * + * @param col the column over which the sketch is built + * @param depth depth of the sketch + * @param width width of the sketch + * @param seed random seed + * @return a [[CountMinSketch]] over column `colName` + * @since 2.0.0 + */ + def countMinSketch(col: Column, depth: Int, width: Int, seed: Int): CountMinSketch = { + countMinSketch(col, CountMinSketch.create(depth, width, seed)) + } + + /** + * Builds a Count-min Sketch over a specified column. + * + * @param col the column over which the sketch is built + * @param eps relative error of the sketch + * @param confidence confidence of the sketch + * @param seed random seed + * @return a [[CountMinSketch]] over column `colName` + * @since 2.0.0 + */ + def countMinSketch(col: Column, eps: Double, confidence: Double, seed: Int): CountMinSketch = { + countMinSketch(col, CountMinSketch.create(eps, confidence, seed)) + } + + private def countMinSketch(col: Column, zero: CountMinSketch): CountMinSketch = { + val singleCol = df.select(col) + val colType = singleCol.schema.head.dataType + + require( + colType == StringType || colType.isInstanceOf[IntegralType], + s"Count-min Sketch only supports string type and integral types, " + + s"and does not support type $colType." 
+ ) + + singleCol.rdd.aggregate(zero)( + (sketch: CountMinSketch, row: Row) => { + sketch.add(row.get(0)) + sketch + }, + + (sketch1: CountMinSketch, sketch2: CountMinSketch) => { + sketch1.mergeInPlace(sketch2) + } + ) + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index ac1607ba3521..9cf94e72d34e 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -35,9 +35,10 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; -import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.types.*; +import org.apache.spark.util.sketch.CountMinSketch; +import static org.apache.spark.sql.functions.*; import static org.apache.spark.sql.types.DataTypes.*; public class JavaDataFrameSuite { @@ -321,4 +322,29 @@ public void testTextLoad() { Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString()); Assert.assertEquals(5L, df2.count()); } + + @Test + public void testCountMinSketch() { + DataFrame df = context.range(1000); + + CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42); + Assert.assertEquals(sketch1.totalCount(), 1000); + Assert.assertEquals(sketch1.depth(), 10); + Assert.assertEquals(sketch1.width(), 20); + + CountMinSketch sketch2 = df.stat().countMinSketch(col("id"), 10, 20, 42); + Assert.assertEquals(sketch2.totalCount(), 1000); + Assert.assertEquals(sketch2.depth(), 10); + Assert.assertEquals(sketch2.width(), 20); + + CountMinSketch sketch3 = df.stat().countMinSketch("id", 0.001, 0.99, 42); + Assert.assertEquals(sketch3.totalCount(), 1000); + Assert.assertEquals(sketch3.relativeError(), 0.001, 1e-4); + Assert.assertEquals(sketch3.confidence(), 0.99, 5e-3); + + CountMinSketch sketch4 = df.stat().countMinSketch(col("id"), 0.001, 0.99, 42); + Assert.assertEquals(sketch4.totalCount(), 1000); + Assert.assertEquals(sketch4.relativeError(), 0.001, 1e-4); + Assert.assertEquals(sketch4.confidence(), 0.99, 5e-3); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 63ad6c439a87..8f3ea5a2860b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -19,8 +19,11 @@ package org.apache.spark.sql import java.util.Random +import org.scalatest.Matchers._ + import org.apache.spark.sql.functions.col import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.DoubleType class DataFrameStatSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -210,4 +213,37 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { sampled.groupBy("key").count().orderBy("key"), Seq(Row(0, 6), Row(1, 11))) } + + // This test case only verifies that `DataFrame.countMinSketch()` methods do return + // `CountMinSketch`es that meet required specs. Test cases for `CountMinSketch` can be found in + // `CountMinSketchSuite` in project spark-sketch. 
+ test("countMinSketch") { + val df = sqlContext.range(1000) + + val sketch1 = df.stat.countMinSketch("id", depth = 10, width = 20, seed = 42) + assert(sketch1.totalCount() === 1000) + assert(sketch1.depth() === 10) + assert(sketch1.width() === 20) + + val sketch2 = df.stat.countMinSketch($"id", depth = 10, width = 20, seed = 42) + assert(sketch2.totalCount() === 1000) + assert(sketch2.depth() === 10) + assert(sketch2.width() === 20) + + val sketch3 = df.stat.countMinSketch("id", eps = 0.001, confidence = 0.99, seed = 42) + assert(sketch3.totalCount() === 1000) + assert(sketch3.relativeError() === 0.001) + assert(sketch3.confidence() === 0.99 +- 5e-3) + + val sketch4 = df.stat.countMinSketch($"id", eps = 0.001, confidence = 0.99, seed = 42) + assert(sketch4.totalCount() === 1000) + assert(sketch4.relativeError() === 0.001 +- 1e04) + assert(sketch4.confidence() === 0.99 +- 5e-3) + + intercept[IllegalArgumentException] { + df.select('id cast DoubleType as 'id) + .stat + .countMinSketch('id, depth = 10, width = 20, seed = 42) + } + } } From 58f5d8c1da6feeb598aa5f74ffe1593d4839d11d Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 26 Jan 2016 20:30:13 -0800 Subject: [PATCH 034/131] [SPARK-12728][SQL] Integrates SQL generation with native view This PR is a follow-up of PR #10541. It integrates the newly introduced SQL generation feature with native view to make native view canonical. In this PR, a new SQL option `spark.sql.nativeView.canonical` is added. When this option and `spark.sql.nativeView` are both `true`, Spark SQL tries to handle `CREATE VIEW` DDL statements using SQL query strings generated from view definition logical plans. If we failed to map the plan to SQL, we fallback to the original native view approach. One important issue this PR fixes is that, now we can use CTE when defining a view. Originally, when native view is turned on, we wrap the view definition text with an extra `SELECT`. However, HiveQL parser doesn't allow CTE appearing as a subquery. Namely, something like this is disallowed: ```sql SELECT n FROM ( WITH w AS (SELECT 1 AS n) SELECT * FROM w ) v ``` This PR fixes this issue because the extra `SELECT` is no longer needed (also, CTE expressions are inlined as subqueries during analysis phase, thus there won't be CTE expressions in the generated SQL query string). Author: Cheng Lian Author: Yin Huai Closes #10733 from liancheng/spark-12728.integrate-sql-gen-with-native-view. 
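Concretely, the new "CTE within view" test below exercises exactly the case described above. Assuming a Hive-enabled `sqlContext` (CREATE VIEW is handled by the Hive support), the end-to-end usage is roughly the following sketch; the config keys come from the SQLConf change in this patch:

```scala
// Enable native views plus the new canonical mode added by this patch.
sqlContext.setConf("spark.sql.nativeView", "true")
sqlContext.setConf("spark.sql.nativeView.canonical", "true")

// The stored view text is now regenerated from the analyzed plan instead of being
// wrapped in an extra outer SELECT, so a CTE in the view definition is accepted.
sqlContext.sql("CREATE VIEW cte_view AS WITH w AS (SELECT 1 AS n) SELECT n FROM w")
sqlContext.sql("SELECT * FROM cte_view").show()  // expected: a single row with n = 1
```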
--- .../scala/org/apache/spark/sql/SQLConf.scala | 10 ++ .../apache/spark/sql/test/SQLTestUtils.scala | 13 ++ .../spark/sql/hive/HiveMetastoreCatalog.scala | 33 ++-- .../hive/execution/CreateViewAsSelect.scala | 95 ++++++++---- .../sql/hive/LogicalPlanToSQLSuite.scala | 2 +- .../sql/hive/execution/SQLQuerySuite.scala | 142 ++++++++++++------ 6 files changed, 200 insertions(+), 95 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 2d664d3ee691..c9ba6700998c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -367,6 +367,14 @@ private[spark] object SQLConf { "possible, or you may get wrong result.", isPublic = false) + val CANONICAL_NATIVE_VIEW = booleanConf("spark.sql.nativeView.canonical", + defaultValue = Some(true), + doc = "When this option and spark.sql.nativeView are both true, Spark SQL tries to handle " + + "CREATE VIEW statement using SQL query string generated from view definition logical " + + "plan. If the logical plan doesn't have a SQL representation, we fallback to the " + + "original native view implementation.", + isPublic = false) + val COLUMN_NAME_OF_CORRUPT_RECORD = stringConf("spark.sql.columnNameOfCorruptRecord", defaultValue = Some("_corrupt_record"), doc = "The name of internal column for storing raw/un-parsed JSON records that fail to parse.") @@ -550,6 +558,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon private[spark] def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED) + private[spark] def canonicalView: Boolean = getConf(CANONICAL_NATIVE_VIEW) + def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) private[spark] def subexpressionEliminationEnabled: Boolean = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 5f73d71d4510..d48143762cac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -154,9 +154,22 @@ private[sql] trait SQLTestUtils } } + /** + * Drops view `viewName` after calling `f`. + */ + protected def withView(viewNames: String*)(f: => Unit): Unit = { + try f finally { + viewNames.foreach { name => + sqlContext.sql(s"DROP VIEW IF EXISTS $name") + } + } + } + /** * Creates a temporary database and switches current database to it before executing `f`. This * database is dropped after `f` returns. + * + * Note that this method doesn't switch current database before executing `f`. 
*/ protected def withTempDatabase(f: String => Unit): Unit = { val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}" diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 80e45d516280..a9c0e9ab7cae 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -579,25 +579,24 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive case p: LogicalPlan if !p.childrenResolved => p case p: LogicalPlan if p.resolved => p - case CreateViewAsSelect(table, child, allowExisting, replace, sql) => - if (conf.nativeView) { - if (allowExisting && replace) { - throw new AnalysisException( - "It is not allowed to define a view with both IF NOT EXISTS and OR REPLACE.") - } + case CreateViewAsSelect(table, child, allowExisting, replace, sql) if conf.nativeView => + if (allowExisting && replace) { + throw new AnalysisException( + "It is not allowed to define a view with both IF NOT EXISTS and OR REPLACE.") + } - val QualifiedTableName(dbName, tblName) = getQualifiedTableName(table) + val QualifiedTableName(dbName, tblName) = getQualifiedTableName(table) - execution.CreateViewAsSelect( - table.copy( - specifiedDatabase = Some(dbName), - name = tblName), - child.output, - allowExisting, - replace) - } else { - HiveNativeCommand(sql) - } + execution.CreateViewAsSelect( + table.copy( + specifiedDatabase = Some(dbName), + name = tblName), + child, + allowExisting, + replace) + + case CreateViewAsSelect(table, child, allowExisting, replace, sql) => + HiveNativeCommand(sql) case p @ CreateTableAsSelect(table, child, allowExisting) => val schema = if (table.schema.nonEmpty) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala index 6e288afbb4d2..31bda56e8a16 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes, SQLBuilder} import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable} /** @@ -32,10 +33,12 @@ import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable} // from Hive and may not work for some cases like create view on self join. 
private[hive] case class CreateViewAsSelect( tableDesc: HiveTable, - childSchema: Seq[Attribute], + child: LogicalPlan, allowExisting: Boolean, orReplace: Boolean) extends RunnableCommand { + private val childSchema = child.output + assert(tableDesc.schema == Nil || tableDesc.schema.length == childSchema.length) assert(tableDesc.viewText.isDefined) @@ -44,55 +47,83 @@ private[hive] case class CreateViewAsSelect( override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] - if (hiveContext.catalog.tableExists(tableIdentifier)) { - if (allowExisting) { - // view already exists, will do nothing, to keep consistent with Hive - } else if (orReplace) { - hiveContext.catalog.client.alertView(prepareTable()) - } else { + hiveContext.catalog.tableExists(tableIdentifier) match { + case true if allowExisting => + // Handles `CREATE VIEW IF NOT EXISTS v0 AS SELECT ...`. Does nothing when the target view + // already exists. + + case true if orReplace => + // Handles `CREATE OR REPLACE VIEW v0 AS SELECT ...` + hiveContext.catalog.client.alertView(prepareTable(sqlContext)) + + case true => + // Handles `CREATE VIEW v0 AS SELECT ...`. Throws exception when the target view already + // exists. throw new AnalysisException(s"View $tableIdentifier already exists. " + "If you want to update the view definition, please use ALTER VIEW AS or " + "CREATE OR REPLACE VIEW AS") - } - } else { - hiveContext.catalog.client.createView(prepareTable()) + + case false => + hiveContext.catalog.client.createView(prepareTable(sqlContext)) } Seq.empty[Row] } - private def prepareTable(): HiveTable = { - // setup column types according to the schema of child. - val schema = if (tableDesc.schema == Nil) { - childSchema.map { attr => - HiveColumn(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), null) - } + private def prepareTable(sqlContext: SQLContext): HiveTable = { + val expandedText = if (sqlContext.conf.canonicalView) { + rebuildViewQueryString(sqlContext).getOrElse(wrapViewTextWithSelect) } else { - childSchema.zip(tableDesc.schema).map { case (attr, col) => - HiveColumn(col.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), col.comment) + wrapViewTextWithSelect + } + + val viewSchema = { + if (tableDesc.schema.isEmpty) { + childSchema.map { attr => + HiveColumn(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), null) + } + } else { + childSchema.zip(tableDesc.schema).map { case (attr, col) => + HiveColumn(col.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), col.comment) + } } } - val columnNames = childSchema.map(f => verbose(f.name)) + tableDesc.copy(schema = viewSchema, viewText = Some(expandedText)) + } + private def wrapViewTextWithSelect: String = { // When user specified column names for view, we should create a project to do the renaming. // When no column name specified, we still need to create a project to declare the columns // we need, to make us more robust to top level `*`s. 
- val projectList = if (tableDesc.schema == Nil) { - columnNames.mkString(", ") - } else { - columnNames.zip(tableDesc.schema.map(f => verbose(f.name))).map { - case (name, alias) => s"$name AS $alias" - }.mkString(", ") + val viewOutput = { + val columnNames = childSchema.map(f => quote(f.name)) + if (tableDesc.schema.isEmpty) { + columnNames.mkString(", ") + } else { + columnNames.zip(tableDesc.schema.map(f => quote(f.name))).map { + case (name, alias) => s"$name AS $alias" + }.mkString(", ") + } } - val viewName = verbose(tableDesc.name) - - val expandedText = s"SELECT $projectList FROM (${tableDesc.viewText.get}) $viewName" + val viewText = tableDesc.viewText.get + val viewName = quote(tableDesc.name) + s"SELECT $viewOutput FROM ($viewText) $viewName" + } - tableDesc.copy(schema = schema, viewText = Some(expandedText)) + private def rebuildViewQueryString(sqlContext: SQLContext): Option[String] = { + val logicalPlan = if (tableDesc.schema.isEmpty) { + child + } else { + val projectList = childSchema.zip(tableDesc.schema).map { + case (attr, col) => Alias(attr, col.name)() + } + sqlContext.executePlan(Project(projectList, child)).analyzed + } + new SQLBuilder(logicalPlan, sqlContext).toSQL } // escape backtick with double-backtick in column name and wrap it with backtick. - private def verbose(name: String) = s"`${name.replaceAll("`", "``")}`" + private def quote(name: String) = s"`${name.replaceAll("`", "``")}`" } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index 261a4746f428..1f731db26f38 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -147,7 +147,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { // TODO Enable this // Query plans transformed by DistinctAggregationRewriter are not recognized yet - ignore("distinct and non-distinct aggregation") { + ignore("multi-distinct columns") { checkHiveQl("SELECT a, COUNT(DISTINCT b), COUNT(DISTINCT c), SUM(d) FROM t2 GROUP BY a") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 683008960aa2..9e53d8a81e75 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1319,67 +1319,119 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } - test("correctly handle CREATE OR REPLACE VIEW") { - withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { - withTable("jt", "jt2") { - sqlContext.range(1, 10).write.format("json").saveAsTable("jt") - sql("CREATE OR REPLACE VIEW testView AS SELECT id FROM jt") - checkAnswer(sql("SELECT * FROM testView ORDER BY id"), (1 to 9).map(i => Row(i))) + Seq(true, false).foreach { enabled => + val prefix = (if (enabled) "With" else "Without") + " canonical native view: " + test(s"$prefix correctly handle CREATE OR REPLACE VIEW") { + withSQLConf( + SQLConf.NATIVE_VIEW.key -> "true", SQLConf.CANONICAL_NATIVE_VIEW.key -> enabled.toString) { + withTable("jt", "jt2") { + sqlContext.range(1, 10).write.format("json").saveAsTable("jt") + sql("CREATE OR REPLACE VIEW testView AS SELECT id FROM jt") + checkAnswer(sql("SELECT * FROM testView ORDER BY id"), (1 to 9).map(i => 
Row(i))) + + val df = (1 until 10).map(i => i -> i).toDF("i", "j") + df.write.format("json").saveAsTable("jt2") + sql("CREATE OR REPLACE VIEW testView AS SELECT * FROM jt2") + // make sure the view has been changed. + checkAnswer(sql("SELECT * FROM testView ORDER BY i"), (1 to 9).map(i => Row(i, i))) + + sql("DROP VIEW testView") + + val e = intercept[AnalysisException] { + sql("CREATE OR REPLACE VIEW IF NOT EXISTS testView AS SELECT id FROM jt") + } + assert(e.message.contains("not allowed to define a view")) + } + } + } - val df = (1 until 10).map(i => i -> i).toDF("i", "j") - df.write.format("json").saveAsTable("jt2") - sql("CREATE OR REPLACE VIEW testView AS SELECT * FROM jt2") - // make sure the view has been changed. - checkAnswer(sql("SELECT * FROM testView ORDER BY i"), (1 to 9).map(i => Row(i, i))) + test(s"$prefix correctly handle ALTER VIEW") { + withSQLConf( + SQLConf.NATIVE_VIEW.key -> "true", SQLConf.CANONICAL_NATIVE_VIEW.key -> enabled.toString) { + withTable("jt", "jt2") { + withView("testView") { + sqlContext.range(1, 10).write.format("json").saveAsTable("jt") + sql("CREATE VIEW testView AS SELECT id FROM jt") + + val df = (1 until 10).map(i => i -> i).toDF("i", "j") + df.write.format("json").saveAsTable("jt2") + sql("ALTER VIEW testView AS SELECT * FROM jt2") + // make sure the view has been changed. + checkAnswer(sql("SELECT * FROM testView ORDER BY i"), (1 to 9).map(i => Row(i, i))) + } + } + } + } - sql("DROP VIEW testView") + test(s"$prefix create hive view for json table") { + // json table is not hive-compatible, make sure the new flag fix it. + withSQLConf( + SQLConf.NATIVE_VIEW.key -> "true", SQLConf.CANONICAL_NATIVE_VIEW.key -> enabled.toString) { + withTable("jt") { + withView("testView") { + sqlContext.range(1, 10).write.format("json").saveAsTable("jt") + sql("CREATE VIEW testView AS SELECT id FROM jt") + checkAnswer(sql("SELECT * FROM testView ORDER BY id"), (1 to 9).map(i => Row(i))) + } + } + } + } - val e = intercept[AnalysisException] { - sql("CREATE OR REPLACE VIEW IF NOT EXISTS testView AS SELECT id FROM jt") + test(s"$prefix create hive view for partitioned parquet table") { + // partitioned parquet table is not hive-compatible, make sure the new flag fix it. + withSQLConf( + SQLConf.NATIVE_VIEW.key -> "true", SQLConf.CANONICAL_NATIVE_VIEW.key -> enabled.toString) { + withTable("parTable") { + withView("testView") { + val df = Seq(1 -> "a").toDF("i", "j") + df.write.format("parquet").partitionBy("i").saveAsTable("parTable") + sql("CREATE VIEW testView AS SELECT i, j FROM parTable") + checkAnswer(sql("SELECT * FROM testView"), Row(1, "a")) + } } - assert(e.message.contains("not allowed to define a view")) } } } - test("correctly handle ALTER VIEW") { - withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { - withTable("jt", "jt2") { - sqlContext.range(1, 10).write.format("json").saveAsTable("jt") - sql("CREATE VIEW testView AS SELECT id FROM jt") - - val df = (1 until 10).map(i => i -> i).toDF("i", "j") - df.write.format("json").saveAsTable("jt2") - sql("ALTER VIEW testView AS SELECT * FROM jt2") - // make sure the view has been changed. 
- checkAnswer(sql("SELECT * FROM testView ORDER BY i"), (1 to 9).map(i => Row(i, i))) - - sql("DROP VIEW testView") + test("CTE within view") { + withSQLConf( + SQLConf.NATIVE_VIEW.key -> "true", SQLConf.CANONICAL_NATIVE_VIEW.key -> "true") { + withView("cte_view") { + sql("CREATE VIEW cte_view AS WITH w AS (SELECT 1 AS n) SELECT n FROM w") + checkAnswer(sql("SELECT * FROM cte_view"), Row(1)) } } } - test("create hive view for json table") { - // json table is not hive-compatible, make sure the new flag fix it. - withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { - withTable("jt") { - sqlContext.range(1, 10).write.format("json").saveAsTable("jt") - sql("CREATE VIEW testView AS SELECT id FROM jt") - checkAnswer(sql("SELECT * FROM testView ORDER BY id"), (1 to 9).map(i => Row(i))) - sql("DROP VIEW testView") + test("Using view after switching current database") { + withSQLConf( + SQLConf.NATIVE_VIEW.key -> "true", SQLConf.CANONICAL_NATIVE_VIEW.key -> "true") { + withView("v") { + sql("CREATE VIEW v AS SELECT * FROM src") + withTempDatabase { db => + activateDatabase(db) { + // Should look up table `src` in database `default`. + checkAnswer(sql("SELECT * FROM default.v"), sql("SELECT * FROM default.src")) + + // The new `src` table shouldn't be scanned. + sql("CREATE TABLE src(key INT, value STRING)") + checkAnswer(sql("SELECT * FROM default.v"), sql("SELECT * FROM default.src")) + } + } } } } - test("create hive view for partitioned parquet table") { - // partitioned parquet table is not hive-compatible, make sure the new flag fix it. - withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { - withTable("parTable") { - val df = Seq(1 -> "a").toDF("i", "j") - df.write.format("parquet").partitionBy("i").saveAsTable("parTable") - sql("CREATE VIEW testView AS SELECT i, j FROM parTable") - checkAnswer(sql("SELECT * FROM testView"), Row(1, "a")) - sql("DROP VIEW testView") + test("Using view after adding more columns") { + withSQLConf( + SQLConf.NATIVE_VIEW.key -> "true", SQLConf.CANONICAL_NATIVE_VIEW.key -> "true") { + withTable("add_col") { + sqlContext.range(10).write.saveAsTable("add_col") + withView("v") { + sql("CREATE VIEW v AS SELECT * FROM add_col") + sqlContext.range(10).select('id, 'id as 'a).write.mode("overwrite").saveAsTable("add_col") + checkAnswer(sql("SELECT * FROM v"), sqlContext.range(10)) + } } } } From bae3c9a4eb0c320999e5dbafd62692c12823e07d Mon Sep 17 00:00:00 2001 From: Nishkam Ravi Date: Tue, 26 Jan 2016 21:14:39 -0800 Subject: [PATCH 035/131] [SPARK-12967][NETTY] Avoid NettyRpc error message during sparkContext shutdown If there's an RPC issue while sparkContext is alive but stopped (which would happen only when executing SparkContext.stop), log a warning instead. This is a common occurrence. vanzin Author: Nishkam Ravi Author: nishkamravi2 Closes #10881 from nishkamravi2/master_netty. 
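The fix itself is a small pattern change at the RPC send sites: a failure caused by an RpcEnv that has already been stopped is caught and logged as a warning rather than surfacing as an error. Because the new `RpcEnvStoppedException` is `private[rpc]` and `logWarning` comes from Spark's internal logging trait, the self-contained sketch below uses stand-ins for both; it only illustrates the catch-and-downgrade pattern the diff applies:

```scala
// Stand-in for the private[rpc] RpcEnvStoppedException introduced by this patch.
class StoppedException extends IllegalStateException("RpcEnv already stopped.")

// Stand-in for an RPC send that may race with SparkContext.stop().
def sendTolerantOfShutdown(send: () => Unit, logWarning: String => Unit): Unit = {
  try {
    send()
  } catch {
    // Expected while the RpcEnv is being torn down: warn instead of failing loudly.
    case e: StoppedException => logWarning(e.getMessage)
  }
}

sendTolerantOfShutdown(() => throw new StoppedException, msg => println(s"WARN: $msg"))
```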
--- .../spark/rpc/RpcEnvStoppedException.scala | 20 +++++++++++++++++++ .../apache/spark/rpc/netty/Dispatcher.scala | 4 ++-- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 6 +++++- .../org/apache/spark/rpc/netty/Outbox.scala | 7 +++++-- 4 files changed, 32 insertions(+), 5 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/rpc/RpcEnvStoppedException.scala diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnvStoppedException.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnvStoppedException.scala new file mode 100644 index 000000000000..c296cc23f12b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnvStoppedException.scala @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.rpc + +private[rpc] class RpcEnvStoppedException() + extends IllegalStateException("RpcEnv already stopped.") diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index 19259e0e800c..6ceff2c07399 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -106,7 +106,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { val iter = endpoints.keySet().iterator() while (iter.hasNext) { val name = iter.next - postMessage(name, message, (e) => logWarning(s"Message $message dropped.", e)) + postMessage(name, message, (e) => logWarning(s"Message $message dropped. ${e.getMessage}")) } } @@ -156,7 +156,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { if (shouldCallOnStop) { // We don't need to call `onStop` in the `synchronized` block val error = if (stopped) { - new IllegalStateException("RpcEnv already stopped.") + new RpcEnvStoppedException() } else { new SparkException(s"Could not find $endpointName or it has been stopped.") } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 9ae74d9d7b89..89eda857e622 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -182,7 +182,11 @@ private[netty] class NettyRpcEnv( val remoteAddr = message.receiver.address if (remoteAddr == address) { // Message to a local RPC endpoint. - dispatcher.postOneWayMessage(message) + try { + dispatcher.postOneWayMessage(message) + } catch { + case e: RpcEnvStoppedException => logWarning(e.getMessage) + } } else { // Message to a remote RPC endpoint. 
postToOutbox(message.receiver, OneWayOutboxMessage(serialize(message))) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala index 2316ebe347bb..9fd64e853575 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala @@ -25,7 +25,7 @@ import scala.util.control.NonFatal import org.apache.spark.{Logging, SparkException} import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} -import org.apache.spark.rpc.RpcAddress +import org.apache.spark.rpc.{RpcAddress, RpcEnvStoppedException} private[netty] sealed trait OutboxMessage { @@ -43,7 +43,10 @@ private[netty] case class OneWayOutboxMessage(content: ByteBuffer) extends Outbo } override def onFailure(e: Throwable): Unit = { - logWarning(s"Failed to send one-way RPC.", e) + e match { + case e1: RpcEnvStoppedException => logWarning(e1.getMessage) + case e1: Throwable => logWarning(s"Failed to send one-way RPC.", e1) + } } } From 4db255c7aa756daa224d61905db745b6bccc9173 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 26 Jan 2016 21:16:56 -0800 Subject: [PATCH 036/131] [SPARK-12780] Inconsistency returning value of ML python models' properties https://issues.apache.org/jira/browse/SPARK-12780 Author: Xusen Yin Closes #10724 from yinxusen/SPARK-12780. --- python/pyspark/ml/feature.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 22081233b04d..d017a231886c 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1323,7 +1323,7 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid): >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]), ... key=lambda x: x[0]) [(0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)] - >>> inverter = IndexToString(inputCol="indexed", outputCol="label2", labels=model.labels()) + >>> inverter = IndexToString(inputCol="indexed", outputCol="label2", labels=model.labels) >>> itd = inverter.transform(td) >>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]), ... key=lambda x: x[0]) @@ -1365,13 +1365,14 @@ class StringIndexerModel(JavaModel): .. versionadded:: 1.4.0 """ + @property @since("1.5.0") def labels(self): """ Ordered list of labels, corresponding to indices to be assigned. """ - return self._java_obj.labels + return self._call_java("labels") @inherit_doc From 90b0e562406a8bac529e190472e7f5da4030bf5c Mon Sep 17 00:00:00 2001 From: BenFradet Date: Wed, 27 Jan 2016 09:27:11 +0000 Subject: [PATCH 037/131] [SPARK-12983][CORE][DOC] Correct metrics.properties.template There are some typos or plain unintelligible sentences in the metrics template. Author: BenFradet Closes #10902 from BenFradet/SPARK-12983. --- conf/metrics.properties.template | 71 +++++++++++++++++--------------- 1 file changed, 37 insertions(+), 34 deletions(-) diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template index d6962e0da2f3..8a4f4e48335b 100644 --- a/conf/metrics.properties.template +++ b/conf/metrics.properties.template @@ -57,39 +57,41 @@ # added to Java properties using -Dspark.metrics.conf=xxx if you want to # customize metrics system. You can also put the file in ${SPARK_HOME}/conf # and it will be loaded automatically. -# 5. 
MetricsServlet is added by default as a sink in master, worker and client -# driver, you can send http request "/metrics/json" to get a snapshot of all the -# registered metrics in json format. For master, requests "/metrics/master/json" and -# "/metrics/applications/json" can be sent seperately to get metrics snapshot of -# instance master and applications. MetricsServlet may not be configured by self. -# +# 5. The MetricsServlet sink is added by default as a sink in the master, +# worker and driver, and you can send HTTP requests to the "/metrics/json" +# endpoint to get a snapshot of all the registered metrics in JSON format. +# For master, requests to the "/metrics/master/json" and +# "/metrics/applications/json" endpoints can be sent separately to get +# metrics snapshots of the master instance and applications. This +# MetricsServlet does not have to be configured. ## List of available common sources and their properties. # org.apache.spark.metrics.source.JvmSource -# Note: Currently, JvmSource is the only available common source -# to add additionaly to an instance, to enable this, -# set the "class" option to its fully qulified class name (see examples below) +# Note: Currently, JvmSource is the only available common source. +# It can be added to an instance by setting the "class" option to its +# fully qualified class name (see examples below). ## List of available sinks and their properties. # org.apache.spark.metrics.sink.ConsoleSink # Name: Default: Description: # period 10 Poll period -# unit seconds Units of poll period +# unit seconds Unit of the poll period # org.apache.spark.metrics.sink.CSVSink # Name: Default: Description: # period 10 Poll period -# unit seconds Units of poll period +# unit seconds Unit of the poll period # directory /tmp Where to store CSV files # org.apache.spark.metrics.sink.GangliaSink # Name: Default: Description: -# host NONE Hostname or multicast group of Ganglia server -# port NONE Port of Ganglia server(s) +# host NONE Hostname or multicast group of the Ganglia server, +# must be set +# port NONE Port of the Ganglia server(s), must be set # period 10 Poll period -# unit seconds Units of poll period +# unit seconds Unit of the poll period # ttl 1 TTL of messages sent by Ganglia # mode multicast Ganglia network mode ('unicast' or 'multicast') @@ -98,19 +100,21 @@ # org.apache.spark.metrics.sink.MetricsServlet # Name: Default: Description: # path VARIES* Path prefix from the web server root -# sample false Whether to show entire set of samples for histograms ('false' or 'true') +# sample false Whether to show entire set of samples for histograms +# ('false' or 'true') # -# * Default path is /metrics/json for all instances except the master. The master has two paths: +# * Default path is /metrics/json for all instances except the master. 
The +# master has two paths: # /metrics/applications/json # App information # /metrics/master/json # Master information # org.apache.spark.metrics.sink.GraphiteSink # Name: Default: Description: -# host NONE Hostname of Graphite server -# port NONE Port of Graphite server +# host NONE Hostname of the Graphite server, must be set +# port NONE Port of the Graphite server, must be set # period 10 Poll period -# unit seconds Units of poll period -# prefix EMPTY STRING Prefix to prepend to metric name +# unit seconds Unit of the poll period +# prefix EMPTY STRING Prefix to prepend to every metric's name # protocol tcp Protocol ("tcp" or "udp") to use ## Examples @@ -120,42 +124,42 @@ # Enable ConsoleSink for all instances by class name #*.sink.console.class=org.apache.spark.metrics.sink.ConsoleSink -# Polling period for ConsoleSink +# Polling period for the ConsoleSink #*.sink.console.period=10 - +# Unit of the polling period for the ConsoleSink #*.sink.console.unit=seconds -# Master instance overlap polling period +# Polling period for the ConsoleSink specific for the master instance #master.sink.console.period=15 - +# Unit of the polling period for the ConsoleSink specific for the master +# instance #master.sink.console.unit=seconds -# Enable CsvSink for all instances +# Enable CsvSink for all instances by class name #*.sink.csv.class=org.apache.spark.metrics.sink.CsvSink -# Polling period for CsvSink +# Polling period for the CsvSink #*.sink.csv.period=1 - +# Unit of the polling period for the CsvSink #*.sink.csv.unit=minutes # Polling directory for CsvSink #*.sink.csv.directory=/tmp/ -# Worker instance overlap polling period +# Polling period for the CsvSink specific for the worker instance #worker.sink.csv.period=10 - +# Unit of the polling period for the CsvSink specific for the worker instance #worker.sink.csv.unit=minutes # Enable Slf4jSink for all instances by class name #*.sink.slf4j.class=org.apache.spark.metrics.sink.Slf4jSink -# Polling period for Slf4JSink +# Polling period for the Slf4JSink #*.sink.slf4j.period=1 - +# Unit of the polling period for the Slf4jSink #*.sink.slf4j.unit=minutes - -# Enable jvm source for instance master, worker, driver and executor +# Enable JvmSource for instance master, worker, driver and executor #master.source.jvm.class=org.apache.spark.metrics.source.JvmSource #worker.source.jvm.class=org.apache.spark.metrics.source.JvmSource @@ -163,4 +167,3 @@ #driver.source.jvm.class=org.apache.spark.metrics.source.JvmSource #executor.source.jvm.class=org.apache.spark.metrics.source.JvmSource - From 093291cf9b8729c0bd057cf67aed840b11f8c94a Mon Sep 17 00:00:00 2001 From: Andrew Date: Wed, 27 Jan 2016 09:31:44 +0000 Subject: [PATCH 038/131] [SPARK-1680][DOCS] Explain environment variables for running on YARN in cluster mode JIRA 1680 added a property called spark.yarn.appMasterEnv. This PR draws users' attention to this special case by adding an explanation in configuration.html#environment-variables Author: Andrew Closes #10869 from weineran/branch-yarn-docs. --- docs/configuration.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index d2a2f1052405..74a8fb5d35a6 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1643,6 +1643,8 @@ to use on each machine and maximum memory. Since `spark-env.sh` is a shell script, some of these can be set programmatically -- for example, you might compute `SPARK_LOCAL_IP` by looking up the IP of a specific network interface. 
+Note: When running Spark on YARN in `cluster` mode, environment variables need to be set using the `spark.yarn.appMasterEnv.[EnvironmentVariableName]` property in your `conf/spark-defaults.conf` file. Environment variables that are set in `spark-env.sh` will not be reflected in the YARN Application Master process in `cluster` mode. See the [YARN-related Spark Properties](running-on-yarn.html#spark-properties) for more information. + # Configuring Logging Spark uses [log4j](http://logging.apache.org/log4j/) for logging. You can configure it by adding a From 41f0c85f9be264103c066935e743f59caf0fe268 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 27 Jan 2016 08:32:13 -0800 Subject: [PATCH 039/131] [SPARK-13023][PROJECT INFRA] Fix handling of root module in modules_to_test() There's a minor bug in how we handle the `root` module in the `modules_to_test()` function in `dev/run-tests.py`: since `root` now depends on `build` (since every test needs to run on any build test), we now need to check for the presence of root in `modules_to_test` instead of `changed_modules`. Author: Josh Rosen Closes #10933 from JoshRosen/build-module-fix. --- dev/run-tests.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index c78a66f6aa54..6febbf108900 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -104,6 +104,8 @@ def determine_modules_to_test(changed_modules): >>> [x.name for x in determine_modules_to_test([modules.root])] ['root'] + >>> [x.name for x in determine_modules_to_test([modules.build])] + ['root'] >>> [x.name for x in determine_modules_to_test([modules.graphx])] ['graphx', 'examples'] >>> x = [x.name for x in determine_modules_to_test([modules.sql])] @@ -111,15 +113,13 @@ def determine_modules_to_test(changed_modules): ['sql', 'hive', 'mllib', 'examples', 'hive-thriftserver', 'pyspark-sql', 'sparkr', 'pyspark-mllib', 'pyspark-ml'] """ - # If we're going to have to run all of the tests, then we can just short-circuit - # and return 'root'. No module depends on root, so if it appears then it will be - # in changed_modules. - if modules.root in changed_modules: - return [modules.root] modules_to_test = set() for module in changed_modules: modules_to_test = modules_to_test.union(determine_modules_to_test(module.dependent_modules)) modules_to_test = modules_to_test.union(set(changed_modules)) + # If we need to run all of the tests, then we should short-circuit and return 'root' + if modules.root in modules_to_test: + return [modules.root] return toposort_flatten( {m: set(m.dependencies).intersection(modules_to_test) for m in modules_to_test}, sort=True) From edd473751b59b55fa3daede5ed7bc19ea8bd7170 Mon Sep 17 00:00:00 2001 From: Jason Lee Date: Wed, 27 Jan 2016 09:55:10 -0800 Subject: [PATCH 040/131] [SPARK-10847][SQL][PYSPARK] Pyspark - DataFrame - Optional Metadata with `None` triggers cryptic failure The error message is now changed from "Do not support type class scala.Tuple2." to "Do not support type class org.json4s.JsonAST$JNull$" to be more informative about what is not supported. Also, StructType metadata now handles JNull correctly, i.e., {'a': None}. test_metadata_null is added to tests.py to show the fix works. Author: Jason Lee Closes #8969 from jasoncl/SPARK-10847. 
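A short usage sketch may make the fix concrete. It is hypothetical and assumes a Spark build that already contains this patch (the `putNull` builder method and the JNull parsing case); the object name and example values are illustrative only.

import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StringType, StructField, StructType}

object MetadataNullExample {
  def main(args: Array[String]): Unit = {
    // Metadata carrying a null value, mirroring the Python schema metadata {'a': None}.
    val built = new MetadataBuilder().putNull("a").build()

    // Parsing JSON whose metadata holds a null used to fail with the cryptic
    // "Do not support type class scala.Tuple2." error; with the JNull case it parses.
    val parsed = Metadata.fromJson("""{"a":null}""")

    // Either form can be attached to a field as usual.
    val schema = StructType(Seq(
      StructField("f1", StringType, nullable = true),
      StructField("f2", StringType, nullable = true, parsed)))

    println(built.contains("a"))            // true
    println(schema.fieldNames.mkString(", "))  // f1, f2
  }
}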
--- python/pyspark/sql/tests.py | 7 +++++++ .../main/scala/org/apache/spark/sql/types/Metadata.scala | 7 ++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 7593b991a780..410efbafe079 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -747,6 +747,13 @@ def test_struct_type(self): except ValueError: self.assertEqual(1, 1) + def test_metadata_null(self): + from pyspark.sql.types import StructType, StringType, StructField + schema = StructType([StructField("f1", StringType(), True, None), + StructField("f2", StringType(), True, {'a': None})]) + rdd = self.sc.parallelize([["a", "b"], ["c", "d"]]) + self.sqlCtx.createDataFrame(rdd, schema) + def test_save_and_load(self): df = self.df tmpPath = tempfile.mkdtemp() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala index 6ee24ee0c191..9e0f9943bc63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -156,7 +156,9 @@ object Metadata { throw new RuntimeException(s"Do not support array of type ${other.getClass}.") } } - case other => + case (key, JNull) => + builder.putNull(key) + case (key, other) => throw new RuntimeException(s"Do not support type ${other.getClass}.") } builder.build() @@ -229,6 +231,9 @@ class MetadataBuilder { this } + /** Puts a null. */ + def putNull(key: String): this.type = put(key, null) + /** Puts a Long. */ def putLong(key: String, value: Long): this.type = put(key, value) From 87abcf7df921a5937fdb2bae8bfb30bfabc4970a Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 27 Jan 2016 11:15:48 -0800 Subject: [PATCH 041/131] [SPARK-12895][SPARK-12896] Migrate TaskMetrics to accumulators The high level idea is that instead of having the executors send both accumulator updates and TaskMetrics, we should have them send only accumulator updates. This eliminates the need to maintain both code paths since one can be implemented in terms of the other. This effort is split into two parts: **SPARK-12895: Implement TaskMetrics using accumulators.** TaskMetrics is basically just a bunch of accumulable fields. This patch makes TaskMetrics a syntactic wrapper around a collection of accumulators so we don't need to send TaskMetrics from the executors to the driver. **SPARK-12896: Send only accumulator updates to the driver.** Now that TaskMetrics are expressed in terms of accumulators, we can capture all TaskMetrics values if we just send accumulator updates from the executors to the driver. This completes the parent issue SPARK-10620. While an effort has been made to preserve as much of the public API as possible, there were a few known breaking DeveloperApi changes that would be very awkward to maintain. I will gather the full list shortly and post it here. Note: This was once part of #10717. This patch is split out into its own patch from there to make it easier for others to review. Other smaller pieces of already been merged into master. Author: Andrew Or Closes #10835 from andrewor14/task-metrics-use-accums. 
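Before the diff, a compact sketch of the central idea may help: every task metric becomes a named, mergeable cell, and the executor reports only (name, value) accumulator updates, from which the driver can rebuild a view of the metrics. This is a minimal illustration under that framing, not Spark's implementation; apart from the two internal metric names taken from the patch, the class and method names below are invented for the example.

// Each metric is a named cell that can be updated locally and read out later.
final class MetricCell(val name: String) {
  private var value: Long = 0L
  def add(v: Long): Unit = value += v
  def setValue(v: Long): Unit = value = v
  def current: Long = value
}

// A task-metrics facade backed entirely by such cells.
final class SketchTaskMetrics {
  private val executorRunTime = new MetricCell("internal.metrics.executorRunTime")
  private val resultSize = new MetricCell("internal.metrics.resultSize")
  private val cells = Seq(executorRunTime, resultSize)

  def setExecutorRunTime(ms: Long): Unit = executorRunTime.setValue(ms)
  def setResultSize(bytes: Long): Unit = resultSize.setValue(bytes)

  // What an executor would ship back instead of a serialized TaskMetrics object.
  def accumulatorUpdates(): Seq[(String, Long)] = cells.map(c => (c.name, c.current))
}

object SketchTaskMetricsExample {
  def main(args: Array[String]): Unit = {
    val metrics = new SketchTaskMetrics
    metrics.setExecutorRunTime(120L)
    metrics.setResultSize(4096L)
    // The driver only ever sees these pairs and reconstructs the metrics from them.
    metrics.accumulatorUpdates().foreach(println)
  }
}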
--- .../shuffle/sort/UnsafeShuffleWriter.java | 8 +- .../scala/org/apache/spark/Accumulable.scala | 86 ++- .../scala/org/apache/spark/Accumulator.scala | 101 +++- .../scala/org/apache/spark/Aggregator.scala | 3 +- .../org/apache/spark/HeartbeatReceiver.scala | 6 +- .../apache/spark/InternalAccumulator.scala | 199 ++++++- .../scala/org/apache/spark/TaskContext.scala | 19 +- .../org/apache/spark/TaskContextImpl.scala | 30 +- .../org/apache/spark/TaskEndReason.scala | 29 +- .../apache/spark/deploy/SparkHadoopUtil.scala | 8 + .../org/apache/spark/executor/Executor.scala | 45 +- .../apache/spark/executor/InputMetrics.scala | 81 ++- .../apache/spark/executor/OutputMetrics.scala | 71 ++- .../spark/executor/ShuffleReadMetrics.scala | 104 ++-- .../spark/executor/ShuffleWriteMetrics.scala | 62 +- .../apache/spark/executor/TaskMetrics.scala | 370 +++++++----- .../org/apache/spark/rdd/CoGroupedRDD.scala | 3 +- .../org/apache/spark/rdd/HadoopRDD.scala | 24 +- .../org/apache/spark/rdd/NewHadoopRDD.scala | 22 +- .../spark/scheduler/AccumulableInfo.scala | 55 +- .../apache/spark/scheduler/DAGScheduler.scala | 102 ++-- .../spark/scheduler/DAGSchedulerEvent.scala | 5 +- .../apache/spark/scheduler/ResultTask.scala | 8 +- .../spark/scheduler/ShuffleMapTask.scala | 8 +- .../spark/scheduler/SparkListener.scala | 7 +- .../org/apache/spark/scheduler/Stage.scala | 4 +- .../org/apache/spark/scheduler/Task.scala | 41 +- .../apache/spark/scheduler/TaskResult.scala | 28 +- .../spark/scheduler/TaskResultGetter.scala | 27 +- .../spark/scheduler/TaskScheduler.scala | 6 +- .../spark/scheduler/TaskSchedulerImpl.scala | 13 +- .../spark/scheduler/TaskSetManager.scala | 14 +- .../shuffle/BlockStoreShuffleReader.scala | 3 +- .../status/api/v1/AllStagesResource.scala | 3 +- .../spark/ui/jobs/JobProgressListener.scala | 18 +- .../org/apache/spark/ui/jobs/StagePage.scala | 36 +- .../org/apache/spark/util/JsonProtocol.scala | 124 +++- .../util/collection/ExternalSorter.scala | 3 +- .../sort/UnsafeShuffleWriterSuite.java | 2 - .../org/apache/spark/AccumulatorSuite.scala | 324 ++++++----- .../ExecutorAllocationManagerSuite.scala | 2 +- .../apache/spark/HeartbeatReceiverSuite.scala | 6 +- .../spark/InternalAccumulatorSuite.scala | 331 +++++++++++ .../org/apache/spark/SparkFunSuite.scala | 2 + .../spark/executor/TaskMetricsSuite.scala | 540 +++++++++++++++++- .../spark/memory/MemoryTestingUtils.scala | 3 +- .../spark/scheduler/DAGSchedulerSuite.scala | 281 ++++----- .../spark/scheduler/ReplayListenerSuite.scala | 8 +- .../spark/scheduler/TaskContextSuite.scala | 56 +- .../scheduler/TaskResultGetterSuite.scala | 67 ++- .../spark/scheduler/TaskSetManagerSuite.scala | 36 +- .../org/apache/spark/ui/StagePageSuite.scala | 11 +- .../ui/jobs/JobProgressListenerSuite.scala | 26 +- .../apache/spark/util/JsonProtocolSuite.scala | 515 ++++++++++++----- project/MimaExcludes.scala | 9 + .../org/apache/spark/sql/execution/Sort.scala | 8 +- .../TungstenAggregationIterator.scala | 6 +- .../datasources/SqlNewHadoopRDD.scala | 22 +- .../execution/joins/BroadcastHashJoin.scala | 3 +- .../joins/BroadcastHashOuterJoin.scala | 3 +- .../joins/BroadcastLeftSemiJoinHash.scala | 3 +- .../sql/execution/metric/SQLMetrics.scala | 6 +- .../spark/sql/execution/ui/SQLListener.scala | 22 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 2 +- .../spark/sql/execution/ReferenceSort.scala | 3 +- .../UnsafeFixedWidthAggregationMapSuite.scala | 3 +- .../UnsafeKVExternalSorterSuite.scala | 3 +- .../columnar/PartitionBatchPruningSuite.scala | 38 +- 
.../sql/execution/ui/SQLListenerSuite.scala | 34 +- .../sql/util/DataFrameCallbackSuite.scala | 2 +- 70 files changed, 3012 insertions(+), 1141 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index d3d79a27ea1c..128a82579b80 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -444,13 +444,7 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th @Override public Option stop(boolean success) { try { - // Update task metrics from accumulators (null in UnsafeShuffleWriterSuite) - Map> internalAccumulators = - taskContext.internalMetricsToAccumulators(); - if (internalAccumulators != null) { - internalAccumulators.apply(InternalAccumulator.PEAK_EXECUTION_MEMORY()) - .add(getPeakMemoryUsedBytes()); - } + taskContext.taskMetrics().incPeakExecutionMemory(getPeakMemoryUsedBytes()); if (stopping) { return Option.apply(null); diff --git a/core/src/main/scala/org/apache/spark/Accumulable.scala b/core/src/main/scala/org/apache/spark/Accumulable.scala index a456d420b8d6..bde136141f40 100644 --- a/core/src/main/scala/org/apache/spark/Accumulable.scala +++ b/core/src/main/scala/org/apache/spark/Accumulable.scala @@ -35,40 +35,67 @@ import org.apache.spark.util.Utils * [[org.apache.spark.Accumulator]]. They won't always be the same, though -- e.g., imagine you are * accumulating a set. You will add items to the set, and you will union two sets together. * + * All accumulators created on the driver to be used on the executors must be registered with + * [[Accumulators]]. This is already done automatically for accumulators created by the user. + * Internal accumulators must be explicitly registered by the caller. + * + * Operations are not thread-safe. + * + * @param id ID of this accumulator; for internal use only. * @param initialValue initial value of accumulator * @param param helper object defining how to add elements of type `R` and `T` * @param name human-readable name for use in Spark's web UI * @param internal if this [[Accumulable]] is internal. Internal [[Accumulable]]s will be reported * to the driver via heartbeats. For internal [[Accumulable]]s, `R` must be * thread safe so that they can be reported correctly. + * @param countFailedValues whether to accumulate values from failed tasks. This is set to true + * for system and time metrics like serialization time or bytes spilled, + * and false for things with absolute values like number of input rows. + * This should be used for internal metrics only. 
* @tparam R the full accumulated data (result type) * @tparam T partial data that can be added in */ -class Accumulable[R, T] private[spark] ( - initialValue: R, +class Accumulable[R, T] private ( + val id: Long, + @transient initialValue: R, param: AccumulableParam[R, T], val name: Option[String], - internal: Boolean) + internal: Boolean, + private[spark] val countFailedValues: Boolean) extends Serializable { private[spark] def this( - @transient initialValue: R, param: AccumulableParam[R, T], internal: Boolean) = { - this(initialValue, param, None, internal) + initialValue: R, + param: AccumulableParam[R, T], + name: Option[String], + internal: Boolean, + countFailedValues: Boolean) = { + this(Accumulators.newId(), initialValue, param, name, internal, countFailedValues) } - def this(@transient initialValue: R, param: AccumulableParam[R, T], name: Option[String]) = - this(initialValue, param, name, false) + private[spark] def this( + initialValue: R, + param: AccumulableParam[R, T], + name: Option[String], + internal: Boolean) = { + this(initialValue, param, name, internal, false /* countFailedValues */) + } - def this(@transient initialValue: R, param: AccumulableParam[R, T]) = - this(initialValue, param, None) + def this(initialValue: R, param: AccumulableParam[R, T], name: Option[String]) = + this(initialValue, param, name, false /* internal */) - val id: Long = Accumulators.newId + def this(initialValue: R, param: AccumulableParam[R, T]) = this(initialValue, param, None) - @volatile @transient private var value_ : R = initialValue // Current value on master - val zero = param.zero(initialValue) // Zero value to be passed to workers + @volatile @transient private var value_ : R = initialValue // Current value on driver + val zero = param.zero(initialValue) // Zero value to be passed to executors private var deserialized = false - Accumulators.register(this) + // In many places we create internal accumulators without access to the active context cleaner, + // so if we register them here then we may never unregister these accumulators. To avoid memory + // leaks, we require the caller to explicitly register internal accumulators elsewhere. + if (!internal) { + Accumulators.register(this) + } /** * If this [[Accumulable]] is internal. Internal [[Accumulable]]s will be reported to the driver @@ -77,6 +104,17 @@ class Accumulable[R, T] private[spark] ( */ private[spark] def isInternal: Boolean = internal + /** + * Return a copy of this [[Accumulable]]. + * + * The copy will have the same ID as the original and will not be registered with + * [[Accumulators]] again. This method exists so that the caller can avoid passing the + * same mutable instance around. + */ + private[spark] def copy(): Accumulable[R, T] = { + new Accumulable[R, T](id, initialValue, param, name, internal, countFailedValues) + } + /** * Add more data to this accumulator / accumulable * @param term the data to add @@ -106,7 +144,7 @@ class Accumulable[R, T] private[spark] ( def merge(term: R) { value_ = param.addInPlace(value_, term)} /** - * Access the accumulator's current value; only allowed on master. + * Access the accumulator's current value; only allowed on driver. */ def value: R = { if (!deserialized) { @@ -128,7 +166,7 @@ class Accumulable[R, T] private[spark] ( def localValue: R = value_ /** - * Set the accumulator's value; only allowed on master. + * Set the accumulator's value; only allowed on driver. 
*/ def value_= (newValue: R) { if (!deserialized) { @@ -139,22 +177,24 @@ class Accumulable[R, T] private[spark] ( } /** - * Set the accumulator's value; only allowed on master + * Set the accumulator's value. For internal use only. */ - def setValue(newValue: R) { - this.value = newValue - } + def setValue(newValue: R): Unit = { value_ = newValue } + + /** + * Set the accumulator's value. For internal use only. + */ + private[spark] def setValueAny(newValue: Any): Unit = { setValue(newValue.asInstanceOf[R]) } // Called by Java when deserializing an object private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { in.defaultReadObject() value_ = zero deserialized = true + // Automatically register the accumulator when it is deserialized with the task closure. - // - // Note internal accumulators sent with task are deserialized before the TaskContext is created - // and are registered in the TaskContext constructor. Other internal accumulators, such SQL - // metrics, still need to register here. + // This is for external accumulators and internal ones that do not represent task level + // metrics, e.g. internal SQL metrics, which are per-operator. val taskContext = TaskContext.get() if (taskContext != null) { taskContext.registerAccumulator(this) diff --git a/core/src/main/scala/org/apache/spark/Accumulator.scala b/core/src/main/scala/org/apache/spark/Accumulator.scala index 007136e6ae34..558bd447e22c 100644 --- a/core/src/main/scala/org/apache/spark/Accumulator.scala +++ b/core/src/main/scala/org/apache/spark/Accumulator.scala @@ -17,9 +17,14 @@ package org.apache.spark -import scala.collection.{mutable, Map} +import java.util.concurrent.atomic.AtomicLong +import javax.annotation.concurrent.GuardedBy + +import scala.collection.mutable import scala.ref.WeakReference +import org.apache.spark.storage.{BlockId, BlockStatus} + /** * A simpler value of [[Accumulable]] where the result type being accumulated is the same @@ -49,14 +54,18 @@ import scala.ref.WeakReference * * @param initialValue initial value of accumulator * @param param helper object defining how to add elements of type `T` + * @param name human-readable name associated with this accumulator + * @param internal whether this accumulator is used internally within Spark only + * @param countFailedValues whether to accumulate values from failed tasks * @tparam T result type */ class Accumulator[T] private[spark] ( @transient private[spark] val initialValue: T, param: AccumulatorParam[T], name: Option[String], - internal: Boolean) - extends Accumulable[T, T](initialValue, param, name, internal) { + internal: Boolean, + override val countFailedValues: Boolean = false) + extends Accumulable[T, T](initialValue, param, name, internal, countFailedValues) { def this(initialValue: T, param: AccumulatorParam[T], name: Option[String]) = { this(initialValue, param, name, false) @@ -75,43 +84,63 @@ private[spark] object Accumulators extends Logging { * This global map holds the original accumulator objects that are created on the driver. * It keeps weak references to these objects so that accumulators can be garbage-collected * once the RDDs and user-code that reference them are cleaned up. + * TODO: Don't use a global map; these should be tied to a SparkContext at the very least. 
*/ + @GuardedBy("Accumulators") val originals = mutable.Map[Long, WeakReference[Accumulable[_, _]]]() - private var lastId: Long = 0 + private val nextId = new AtomicLong(0L) - def newId(): Long = synchronized { - lastId += 1 - lastId - } + /** + * Return a globally unique ID for a new [[Accumulable]]. + * Note: Once you copy the [[Accumulable]] the ID is no longer unique. + */ + def newId(): Long = nextId.getAndIncrement + /** + * Register an [[Accumulable]] created on the driver such that it can be used on the executors. + * + * All accumulators registered here can later be used as a container for accumulating partial + * values across multiple tasks. This is what [[org.apache.spark.scheduler.DAGScheduler]] does. + * Note: if an accumulator is registered here, it should also be registered with the active + * context cleaner for cleanup so as to avoid memory leaks. + * + * If an [[Accumulable]] with the same ID was already registered, this does nothing instead + * of overwriting it. This happens when we copy accumulators, e.g. when we reconstruct + * [[org.apache.spark.executor.TaskMetrics]] from accumulator updates. + */ def register(a: Accumulable[_, _]): Unit = synchronized { - originals(a.id) = new WeakReference[Accumulable[_, _]](a) + if (!originals.contains(a.id)) { + originals(a.id) = new WeakReference[Accumulable[_, _]](a) + } } - def remove(accId: Long) { - synchronized { - originals.remove(accId) - } + /** + * Unregister the [[Accumulable]] with the given ID, if any. + */ + def remove(accId: Long): Unit = synchronized { + originals.remove(accId) } - // Add values to the original accumulators with some given IDs - def add(values: Map[Long, Any]): Unit = synchronized { - for ((id, value) <- values) { - if (originals.contains(id)) { - // Since we are now storing weak references, we must check whether the underlying data - // is valid. - originals(id).get match { - case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]] ++= value - case None => - throw new IllegalAccessError("Attempted to access garbage collected Accumulator.") - } - } else { - logWarning(s"Ignoring accumulator update for unknown accumulator id $id") + /** + * Return the [[Accumulable]] registered with the given ID, if any. + */ + def get(id: Long): Option[Accumulable[_, _]] = synchronized { + originals.get(id).map { weakRef => + // Since we are storing weak references, we must check whether the underlying data is valid. + weakRef.get.getOrElse { + throw new IllegalAccessError(s"Attempted to access garbage collected accumulator $id") } } } + /** + * Clear all registered [[Accumulable]]s. For testing only. + */ + def clear(): Unit = synchronized { + originals.clear() + } + } @@ -156,5 +185,23 @@ object AccumulatorParam { def zero(initialValue: Float): Float = 0f } - // TODO: Add AccumulatorParams for other types, e.g. lists and strings + // Note: when merging values, this param just adopts the newer value. This is used only + // internally for things that shouldn't really be accumulated across tasks, like input + // read method, which should be the same across all tasks in the same stage. + private[spark] object StringAccumulatorParam extends AccumulatorParam[String] { + def addInPlace(t1: String, t2: String): String = t2 + def zero(initialValue: String): String = "" + } + + // Note: this is expensive as it makes a copy of the list every time the caller adds an item. + // A better way to use this is to first accumulate the values yourself then them all at once. 
+ private[spark] class ListAccumulatorParam[T] extends AccumulatorParam[Seq[T]] { + def addInPlace(t1: Seq[T], t2: Seq[T]): Seq[T] = t1 ++ t2 + def zero(initialValue: Seq[T]): Seq[T] = Seq.empty[T] + } + + // For the internal metric that records what blocks are updated in a particular task + private[spark] object UpdatedBlockStatusesAccumulatorParam + extends ListAccumulatorParam[(BlockId, BlockStatus)] + } diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index 62629000cfc2..e493d9a3cf9c 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -57,8 +57,7 @@ case class Aggregator[K, V, C] ( Option(context).foreach { c => c.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) c.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) - c.internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) + c.taskMetrics().incPeakExecutionMemory(map.peakMemoryUsedBytes) } } } diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index e03977828b86..45b20c0e8d60 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -35,7 +35,7 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} */ private[spark] case class Heartbeat( executorId: String, - taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics + accumUpdates: Array[(Long, Seq[AccumulableInfo])], // taskId -> accum updates blockManagerId: BlockManagerId) /** @@ -119,14 +119,14 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) context.reply(true) // Messages received from executors - case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) => + case heartbeat @ Heartbeat(executorId, accumUpdates, blockManagerId) => if (scheduler != null) { if (executorLastSeen.contains(executorId)) { executorLastSeen(executorId) = clock.getTimeMillis() eventLoopThread.submit(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { val unknownExecutor = !scheduler.executorHeartbeatReceived( - executorId, taskMetrics, blockManagerId) + executorId, accumUpdates, blockManagerId) val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) context.reply(response) } diff --git a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala index 6ea997c079f3..c191122c0630 100644 --- a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala +++ b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala @@ -17,23 +17,169 @@ package org.apache.spark +import org.apache.spark.storage.{BlockId, BlockStatus} -// This is moved to its own file because many more things will be added to it in SPARK-10620. + +/** + * A collection of fields and methods concerned with internal accumulators that represent + * task level metrics. + */ private[spark] object InternalAccumulator { - val PEAK_EXECUTION_MEMORY = "peakExecutionMemory" - val TEST_ACCUMULATOR = "testAccumulator" - - // For testing only. - // This needs to be a def since we don't want to reuse the same accumulator across stages. 
- private def maybeTestAccumulator: Option[Accumulator[Long]] = { - if (sys.props.contains("spark.testing")) { - Some(new Accumulator( - 0L, AccumulatorParam.LongAccumulatorParam, Some(TEST_ACCUMULATOR), internal = true)) - } else { - None + + import AccumulatorParam._ + + // Prefixes used in names of internal task level metrics + val METRICS_PREFIX = "internal.metrics." + val SHUFFLE_READ_METRICS_PREFIX = METRICS_PREFIX + "shuffle.read." + val SHUFFLE_WRITE_METRICS_PREFIX = METRICS_PREFIX + "shuffle.write." + val OUTPUT_METRICS_PREFIX = METRICS_PREFIX + "output." + val INPUT_METRICS_PREFIX = METRICS_PREFIX + "input." + + // Names of internal task level metrics + val EXECUTOR_DESERIALIZE_TIME = METRICS_PREFIX + "executorDeserializeTime" + val EXECUTOR_RUN_TIME = METRICS_PREFIX + "executorRunTime" + val RESULT_SIZE = METRICS_PREFIX + "resultSize" + val JVM_GC_TIME = METRICS_PREFIX + "jvmGCTime" + val RESULT_SERIALIZATION_TIME = METRICS_PREFIX + "resultSerializationTime" + val MEMORY_BYTES_SPILLED = METRICS_PREFIX + "memoryBytesSpilled" + val DISK_BYTES_SPILLED = METRICS_PREFIX + "diskBytesSpilled" + val PEAK_EXECUTION_MEMORY = METRICS_PREFIX + "peakExecutionMemory" + val UPDATED_BLOCK_STATUSES = METRICS_PREFIX + "updatedBlockStatuses" + val TEST_ACCUM = METRICS_PREFIX + "testAccumulator" + + // scalastyle:off + + // Names of shuffle read metrics + object shuffleRead { + val REMOTE_BLOCKS_FETCHED = SHUFFLE_READ_METRICS_PREFIX + "remoteBlocksFetched" + val LOCAL_BLOCKS_FETCHED = SHUFFLE_READ_METRICS_PREFIX + "localBlocksFetched" + val REMOTE_BYTES_READ = SHUFFLE_READ_METRICS_PREFIX + "remoteBytesRead" + val LOCAL_BYTES_READ = SHUFFLE_READ_METRICS_PREFIX + "localBytesRead" + val FETCH_WAIT_TIME = SHUFFLE_READ_METRICS_PREFIX + "fetchWaitTime" + val RECORDS_READ = SHUFFLE_READ_METRICS_PREFIX + "recordsRead" + } + + // Names of shuffle write metrics + object shuffleWrite { + val BYTES_WRITTEN = SHUFFLE_WRITE_METRICS_PREFIX + "bytesWritten" + val RECORDS_WRITTEN = SHUFFLE_WRITE_METRICS_PREFIX + "recordsWritten" + val WRITE_TIME = SHUFFLE_WRITE_METRICS_PREFIX + "writeTime" + } + + // Names of output metrics + object output { + val WRITE_METHOD = OUTPUT_METRICS_PREFIX + "writeMethod" + val BYTES_WRITTEN = OUTPUT_METRICS_PREFIX + "bytesWritten" + val RECORDS_WRITTEN = OUTPUT_METRICS_PREFIX + "recordsWritten" + } + + // Names of input metrics + object input { + val READ_METHOD = INPUT_METRICS_PREFIX + "readMethod" + val BYTES_READ = INPUT_METRICS_PREFIX + "bytesRead" + val RECORDS_READ = INPUT_METRICS_PREFIX + "recordsRead" + } + + // scalastyle:on + + /** + * Create an internal [[Accumulator]] by name, which must begin with [[METRICS_PREFIX]]. + */ + def create(name: String): Accumulator[_] = { + require(name.startsWith(METRICS_PREFIX), + s"internal accumulator name must start with '$METRICS_PREFIX': $name") + getParam(name) match { + case p @ LongAccumulatorParam => newMetric[Long](0L, name, p) + case p @ IntAccumulatorParam => newMetric[Int](0, name, p) + case p @ StringAccumulatorParam => newMetric[String]("", name, p) + case p @ UpdatedBlockStatusesAccumulatorParam => + newMetric[Seq[(BlockId, BlockStatus)]](Seq(), name, p) + case p => throw new IllegalArgumentException( + s"unsupported accumulator param '${p.getClass.getSimpleName}' for metric '$name'.") + } + } + + /** + * Get the [[AccumulatorParam]] associated with the internal metric name, + * which must begin with [[METRICS_PREFIX]]. 
+ */ + def getParam(name: String): AccumulatorParam[_] = { + require(name.startsWith(METRICS_PREFIX), + s"internal accumulator name must start with '$METRICS_PREFIX': $name") + name match { + case UPDATED_BLOCK_STATUSES => UpdatedBlockStatusesAccumulatorParam + case shuffleRead.LOCAL_BLOCKS_FETCHED => IntAccumulatorParam + case shuffleRead.REMOTE_BLOCKS_FETCHED => IntAccumulatorParam + case input.READ_METHOD => StringAccumulatorParam + case output.WRITE_METHOD => StringAccumulatorParam + case _ => LongAccumulatorParam } } + /** + * Accumulators for tracking internal metrics. + */ + def create(): Seq[Accumulator[_]] = { + Seq[String]( + EXECUTOR_DESERIALIZE_TIME, + EXECUTOR_RUN_TIME, + RESULT_SIZE, + JVM_GC_TIME, + RESULT_SERIALIZATION_TIME, + MEMORY_BYTES_SPILLED, + DISK_BYTES_SPILLED, + PEAK_EXECUTION_MEMORY, + UPDATED_BLOCK_STATUSES).map(create) ++ + createShuffleReadAccums() ++ + createShuffleWriteAccums() ++ + createInputAccums() ++ + createOutputAccums() ++ + sys.props.get("spark.testing").map(_ => create(TEST_ACCUM)).toSeq + } + + /** + * Accumulators for tracking shuffle read metrics. + */ + def createShuffleReadAccums(): Seq[Accumulator[_]] = { + Seq[String]( + shuffleRead.REMOTE_BLOCKS_FETCHED, + shuffleRead.LOCAL_BLOCKS_FETCHED, + shuffleRead.REMOTE_BYTES_READ, + shuffleRead.LOCAL_BYTES_READ, + shuffleRead.FETCH_WAIT_TIME, + shuffleRead.RECORDS_READ).map(create) + } + + /** + * Accumulators for tracking shuffle write metrics. + */ + def createShuffleWriteAccums(): Seq[Accumulator[_]] = { + Seq[String]( + shuffleWrite.BYTES_WRITTEN, + shuffleWrite.RECORDS_WRITTEN, + shuffleWrite.WRITE_TIME).map(create) + } + + /** + * Accumulators for tracking input metrics. + */ + def createInputAccums(): Seq[Accumulator[_]] = { + Seq[String]( + input.READ_METHOD, + input.BYTES_READ, + input.RECORDS_READ).map(create) + } + + /** + * Accumulators for tracking output metrics. + */ + def createOutputAccums(): Seq[Accumulator[_]] = { + Seq[String]( + output.WRITE_METHOD, + output.BYTES_WRITTEN, + output.RECORDS_WRITTEN).map(create) + } + /** * Accumulators for tracking internal metrics. * @@ -41,18 +187,23 @@ private[spark] object InternalAccumulator { * add to the same set of accumulators. We do this to report the distribution of accumulator * values across all tasks within each stage. */ - def create(sc: SparkContext): Seq[Accumulator[Long]] = { - val internalAccumulators = Seq( - // Execution memory refers to the memory used by internal data structures created - // during shuffles, aggregations and joins. The value of this accumulator should be - // approximately the sum of the peak sizes across all such data structures created - // in this task. For SQL jobs, this only tracks all unsafe operators and ExternalSort. - new Accumulator( - 0L, AccumulatorParam.LongAccumulatorParam, Some(PEAK_EXECUTION_MEMORY), internal = true) - ) ++ maybeTestAccumulator.toSeq - internalAccumulators.foreach { accumulator => - sc.cleaner.foreach(_.registerAccumulatorForCleanup(accumulator)) + def create(sc: SparkContext): Seq[Accumulator[_]] = { + val accums = create() + accums.foreach { accum => + Accumulators.register(accum) + sc.cleaner.foreach(_.registerAccumulatorForCleanup(accum)) } - internalAccumulators + accums + } + + /** + * Create a new accumulator representing an internal task metric. 
+ */ + private def newMetric[T]( + initialValue: T, + name: String, + param: AccumulatorParam[T]): Accumulator[T] = { + new Accumulator[T](initialValue, param, Some(name), internal = true, countFailedValues = true) } + } diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 7704abc13409..9f49cf1c4c9b 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -64,7 +64,7 @@ object TaskContext { * An empty task context that does not represent an actual task. */ private[spark] def empty(): TaskContextImpl = { - new TaskContextImpl(0, 0, 0, 0, null, null, Seq.empty) + new TaskContextImpl(0, 0, 0, 0, null, null) } } @@ -138,7 +138,6 @@ abstract class TaskContext extends Serializable { */ def taskAttemptId(): Long - /** ::DeveloperApi:: */ @DeveloperApi def taskMetrics(): TaskMetrics @@ -161,20 +160,4 @@ abstract class TaskContext extends Serializable { */ private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit - /** - * Return the local values of internal accumulators that belong to this task. The key of the Map - * is the accumulator id and the value of the Map is the latest accumulator local value. - */ - private[spark] def collectInternalAccumulators(): Map[Long, Any] - - /** - * Return the local values of accumulators that belong to this task. The key of the Map is the - * accumulator id and the value of the Map is the latest accumulator local value. - */ - private[spark] def collectAccumulators(): Map[Long, Any] - - /** - * Accumulators for tracking internal metrics indexed by the name. - */ - private[spark] val internalMetricsToAccumulators: Map[String, Accumulator[Long]] } diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 94ff884b742b..27ca46f73d8c 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -17,7 +17,7 @@ package org.apache.spark -import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.collection.mutable.ArrayBuffer import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager @@ -32,11 +32,15 @@ private[spark] class TaskContextImpl( override val attemptNumber: Int, override val taskMemoryManager: TaskMemoryManager, @transient private val metricsSystem: MetricsSystem, - internalAccumulators: Seq[Accumulator[Long]], - val taskMetrics: TaskMetrics = TaskMetrics.empty) + initialAccumulators: Seq[Accumulator[_]] = InternalAccumulator.create()) extends TaskContext with Logging { + /** + * Metrics associated with this task. + */ + override val taskMetrics: TaskMetrics = new TaskMetrics(initialAccumulators) + // List of callback functions to execute when the task completes. 
@transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener] @@ -91,24 +95,8 @@ private[spark] class TaskContextImpl( override def getMetricsSources(sourceName: String): Seq[Source] = metricsSystem.getSourcesByName(sourceName) - @transient private val accumulators = new HashMap[Long, Accumulable[_, _]] - - private[spark] override def registerAccumulator(a: Accumulable[_, _]): Unit = synchronized { - accumulators(a.id) = a - } - - private[spark] override def collectInternalAccumulators(): Map[Long, Any] = synchronized { - accumulators.filter(_._2.isInternal).mapValues(_.localValue).toMap + private[spark] override def registerAccumulator(a: Accumulable[_, _]): Unit = { + taskMetrics.registerAccumulator(a) } - private[spark] override def collectAccumulators(): Map[Long, Any] = synchronized { - accumulators.mapValues(_.localValue).toMap - } - - private[spark] override val internalMetricsToAccumulators: Map[String, Accumulator[Long]] = { - // Explicitly register internal accumulators here because these are - // not captured in the task closure and are already deserialized - internalAccumulators.foreach(registerAccumulator) - internalAccumulators.map { a => (a.name.get, a) }.toMap - } } diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 13241b77bf97..68340cc704da 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -19,8 +19,11 @@ package org.apache.spark import java.io.{ObjectInputStream, ObjectOutputStream} +import scala.util.Try + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler.AccumulableInfo import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils @@ -115,22 +118,34 @@ case class ExceptionFailure( description: String, stackTrace: Array[StackTraceElement], fullStackTrace: String, - metrics: Option[TaskMetrics], - private val exceptionWrapper: Option[ThrowableSerializationWrapper]) + exceptionWrapper: Option[ThrowableSerializationWrapper], + accumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo]) extends TaskFailedReason { + @deprecated("use accumUpdates instead", "2.0.0") + val metrics: Option[TaskMetrics] = { + if (accumUpdates.nonEmpty) { + Try(TaskMetrics.fromAccumulatorUpdates(accumUpdates)).toOption + } else { + None + } + } + /** * `preserveCause` is used to keep the exception itself so it is available to the * driver. This may be set to `false` in the event that the exception is not in fact * serializable. 
*/ - private[spark] def this(e: Throwable, metrics: Option[TaskMetrics], preserveCause: Boolean) { - this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics, - if (preserveCause) Some(new ThrowableSerializationWrapper(e)) else None) + private[spark] def this( + e: Throwable, + accumUpdates: Seq[AccumulableInfo], + preserveCause: Boolean) { + this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), + if (preserveCause) Some(new ThrowableSerializationWrapper(e)) else None, accumUpdates) } - private[spark] def this(e: Throwable, metrics: Option[TaskMetrics]) { - this(e, metrics, preserveCause = true) + private[spark] def this(e: Throwable, accumUpdates: Seq[AccumulableInfo]) { + this(e, accumUpdates, preserveCause = true) } def exception: Option[Throwable] = exceptionWrapper.flatMap { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 8ba3f5e24189..06b5101b1f56 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -370,6 +370,14 @@ object SparkHadoopUtil { val SPARK_YARN_CREDS_COUNTER_DELIM = "-" + /** + * Number of records to update input metrics when reading from HadoopRDDs. + * + * Each update is potentially expensive because we need to use reflection to access the + * Hadoop FileSystem API of interest (only available in 2.5), so we should do this sparingly. + */ + private[spark] val UPDATE_INPUT_METRICS_INTERVAL_RECORDS = 1000 + def get: SparkHadoopUtil = { // Check each time to support changing to/from YARN val yarnMode = java.lang.Boolean.valueOf( diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 030ae41db4a6..51c000ea5c57 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -31,7 +31,7 @@ import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.rpc.RpcTimeout -import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task} +import org.apache.spark.scheduler.{AccumulableInfo, DirectTaskResult, IndirectTaskResult, Task} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} import org.apache.spark.util._ @@ -210,7 +210,7 @@ private[spark] class Executor( // Run the actual task and measure its runtime. 
taskStart = System.currentTimeMillis() var threwException = true - val (value, accumUpdates) = try { + val value = try { val res = task.run( taskAttemptId = taskId, attemptNumber = attemptNumber, @@ -249,10 +249,11 @@ private[spark] class Executor( m.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime) m.setJvmGCTime(computeTotalGcTime() - startGCTime) m.setResultSerializationTime(afterSerialization - beforeSerialization) - m.updateAccumulators() } - val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull) + // Note: accumulator updates must be collected after TaskMetrics is updated + val accumUpdates = task.collectAccumulatorUpdates() + val directResult = new DirectTaskResult(valueBytes, accumUpdates) val serializedDirectResult = ser.serialize(directResult) val resultSize = serializedDirectResult.limit @@ -297,21 +298,25 @@ private[spark] class Executor( // the default uncaught exception handler, which will terminate the Executor. logError(s"Exception in $taskName (TID $taskId)", t) - val metrics: Option[TaskMetrics] = Option(task).flatMap { task => - task.metrics.map { m => + // Collect latest accumulator values to report back to the driver + val accumulatorUpdates: Seq[AccumulableInfo] = + if (task != null) { + task.metrics.foreach { m => m.setExecutorRunTime(System.currentTimeMillis() - taskStart) m.setJvmGCTime(computeTotalGcTime() - startGCTime) - m.updateAccumulators() - m } + task.collectAccumulatorUpdates(taskFailed = true) + } else { + Seq.empty[AccumulableInfo] } + val serializedTaskEndReason = { try { - ser.serialize(new ExceptionFailure(t, metrics)) + ser.serialize(new ExceptionFailure(t, accumulatorUpdates)) } catch { case _: NotSerializableException => // t is not serializable so just send the stacktrace - ser.serialize(new ExceptionFailure(t, metrics, false)) + ser.serialize(new ExceptionFailure(t, accumulatorUpdates, preserveCause = false)) } } execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason) @@ -418,33 +423,21 @@ private[spark] class Executor( /** Reports heartbeat and metrics for active tasks to the driver. 
*/ private def reportHeartBeat(): Unit = { - // list of (task id, metrics) to send back to the driver - val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]() + // list of (task id, accumUpdates) to send back to the driver + val accumUpdates = new ArrayBuffer[(Long, Seq[AccumulableInfo])]() val curGCTime = computeTotalGcTime() for (taskRunner <- runningTasks.values().asScala) { if (taskRunner.task != null) { taskRunner.task.metrics.foreach { metrics => metrics.mergeShuffleReadMetrics() - metrics.updateInputMetrics() metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime) - metrics.updateAccumulators() - - if (isLocal) { - // JobProgressListener will hold an reference of it during - // onExecutorMetricsUpdate(), then JobProgressListener can not see - // the changes of metrics any more, so make a deep copy of it - val copiedMetrics = Utils.deserialize[TaskMetrics](Utils.serialize(metrics)) - tasksMetrics += ((taskRunner.taskId, copiedMetrics)) - } else { - // It will be copied by serialization - tasksMetrics += ((taskRunner.taskId, metrics)) - } + accumUpdates += ((taskRunner.taskId, metrics.accumulatorUpdates())) } } } - val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) + val message = Heartbeat(executorId, accumUpdates.toArray, env.blockManager.blockManagerId) try { val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse]( message, RpcTimeout(conf, "spark.executor.heartbeatInterval", "10s")) diff --git a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala index 8f1d7f89a44b..ed9e157ce758 100644 --- a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala @@ -17,13 +17,15 @@ package org.apache.spark.executor +import org.apache.spark.{Accumulator, InternalAccumulator} import org.apache.spark.annotation.DeveloperApi /** * :: DeveloperApi :: - * Method by which input data was read. Network means that the data was read over the network + * Method by which input data was read. Network means that the data was read over the network * from a remote block manager (which may have stored the data on-disk or in-memory). + * Operations are not thread-safe. */ @DeveloperApi object DataReadMethod extends Enumeration with Serializable { @@ -34,44 +36,75 @@ object DataReadMethod extends Enumeration with Serializable { /** * :: DeveloperApi :: - * Metrics about reading input data. + * A collection of accumulators that represents metrics about reading data from external systems. */ @DeveloperApi -case class InputMetrics(readMethod: DataReadMethod.Value) { +class InputMetrics private ( + _bytesRead: Accumulator[Long], + _recordsRead: Accumulator[Long], + _readMethod: Accumulator[String]) + extends Serializable { + + private[executor] def this(accumMap: Map[String, Accumulator[_]]) { + this( + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.input.BYTES_READ), + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.input.RECORDS_READ), + TaskMetrics.getAccum[String](accumMap, InternalAccumulator.input.READ_METHOD)) + } /** - * This is volatile so that it is visible to the updater thread. + * Create a new [[InputMetrics]] that is not associated with any particular task. + * + * This mainly exists because of SPARK-5225, where we are forced to use a dummy [[InputMetrics]] + * because we want to ignore metrics from a second read method. 
In the future, we should revisit + * whether this is needed. + * + * A better alternative is [[TaskMetrics.registerInputMetrics]]. */ - @volatile @transient var bytesReadCallback: Option[() => Long] = None + private[executor] def this() { + this(InternalAccumulator.createInputAccums() + .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]]) + } /** - * Total bytes read. + * Total number of bytes read. */ - private var _bytesRead: Long = _ - def bytesRead: Long = _bytesRead - def incBytesRead(bytes: Long): Unit = _bytesRead += bytes + def bytesRead: Long = _bytesRead.localValue /** - * Total records read. + * Total number of records read. */ - private var _recordsRead: Long = _ - def recordsRead: Long = _recordsRead - def incRecordsRead(records: Long): Unit = _recordsRead += records + def recordsRead: Long = _recordsRead.localValue /** - * Invoke the bytesReadCallback and mutate bytesRead. + * The source from which this task reads its input. */ - def updateBytesRead() { - bytesReadCallback.foreach { c => - _bytesRead = c() - } + def readMethod: DataReadMethod.Value = DataReadMethod.withName(_readMethod.localValue) + + @deprecated("incrementing input metrics is for internal use only", "2.0.0") + def incBytesRead(v: Long): Unit = _bytesRead.add(v) + @deprecated("incrementing input metrics is for internal use only", "2.0.0") + def incRecordsRead(v: Long): Unit = _recordsRead.add(v) + private[spark] def setBytesRead(v: Long): Unit = _bytesRead.setValue(v) + private[spark] def setReadMethod(v: DataReadMethod.Value): Unit = + _readMethod.setValue(v.toString) + +} + +/** + * Deprecated methods to preserve case class matching behavior before Spark 2.0. + */ +object InputMetrics { + + @deprecated("matching on InputMetrics will not be supported in the future", "2.0.0") + def apply(readMethod: DataReadMethod.Value): InputMetrics = { + val im = new InputMetrics + im.setReadMethod(readMethod) + im } - /** - * Register a function that can be called to get up-to-date information on how many bytes the task - * has read from an input source. - */ - def setBytesReadCallback(f: Option[() => Long]) { - bytesReadCallback = f + @deprecated("matching on InputMetrics will not be supported in the future", "2.0.0") + def unapply(input: InputMetrics): Option[DataReadMethod.Value] = { + Some(input.readMethod) } } diff --git a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala index ad132d004cde..0b37d559c746 100644 --- a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala @@ -17,12 +17,14 @@ package org.apache.spark.executor +import org.apache.spark.{Accumulator, InternalAccumulator} import org.apache.spark.annotation.DeveloperApi /** * :: DeveloperApi :: * Method by which output data was written. + * Operations are not thread-safe. */ @DeveloperApi object DataWriteMethod extends Enumeration with Serializable { @@ -33,21 +35,70 @@ object DataWriteMethod extends Enumeration with Serializable { /** * :: DeveloperApi :: - * Metrics about writing output data. + * A collection of accumulators that represents metrics about writing data to external systems. 
*/ @DeveloperApi -case class OutputMetrics(writeMethod: DataWriteMethod.Value) { +class OutputMetrics private ( + _bytesWritten: Accumulator[Long], + _recordsWritten: Accumulator[Long], + _writeMethod: Accumulator[String]) + extends Serializable { + + private[executor] def this(accumMap: Map[String, Accumulator[_]]) { + this( + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.output.BYTES_WRITTEN), + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.output.RECORDS_WRITTEN), + TaskMetrics.getAccum[String](accumMap, InternalAccumulator.output.WRITE_METHOD)) + } + + /** + * Create a new [[OutputMetrics]] that is not associated with any particular task. + * + * This is only used for preserving matching behavior on [[OutputMetrics]], which used to be + * a case class before Spark 2.0. Once we remove support for matching on [[OutputMetrics]] + * we can remove this constructor as well. + */ + private[executor] def this() { + this(InternalAccumulator.createOutputAccums() + .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]]) + } + + /** + * Total number of bytes written. + */ + def bytesWritten: Long = _bytesWritten.localValue + /** - * Total bytes written + * Total number of records written. */ - private var _bytesWritten: Long = _ - def bytesWritten: Long = _bytesWritten - private[spark] def setBytesWritten(value : Long): Unit = _bytesWritten = value + def recordsWritten: Long = _recordsWritten.localValue /** - * Total records written + * The source to which this task writes its output. */ - private var _recordsWritten: Long = 0L - def recordsWritten: Long = _recordsWritten - private[spark] def setRecordsWritten(value: Long): Unit = _recordsWritten = value + def writeMethod: DataWriteMethod.Value = DataWriteMethod.withName(_writeMethod.localValue) + + private[spark] def setBytesWritten(v: Long): Unit = _bytesWritten.setValue(v) + private[spark] def setRecordsWritten(v: Long): Unit = _recordsWritten.setValue(v) + private[spark] def setWriteMethod(v: DataWriteMethod.Value): Unit = + _writeMethod.setValue(v.toString) + +} + +/** + * Deprecated methods to preserve case class matching behavior before Spark 2.0. + */ +object OutputMetrics { + + @deprecated("matching on OutputMetrics will not be supported in the future", "2.0.0") + def apply(writeMethod: DataWriteMethod.Value): OutputMetrics = { + val om = new OutputMetrics + om.setWriteMethod(writeMethod) + om + } + + @deprecated("matching on OutputMetrics will not be supported in the future", "2.0.0") + def unapply(output: OutputMetrics): Option[DataWriteMethod.Value] = { + Some(output.writeMethod) + } } diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala index e985b35ace62..50bb645d974a 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala @@ -17,71 +17,103 @@ package org.apache.spark.executor +import org.apache.spark.{Accumulator, InternalAccumulator} import org.apache.spark.annotation.DeveloperApi /** * :: DeveloperApi :: - * Metrics pertaining to shuffle data read in a given task. + * A collection of accumulators that represent metrics about reading shuffle data. + * Operations are not thread-safe. 
*/ @DeveloperApi -class ShuffleReadMetrics extends Serializable { +class ShuffleReadMetrics private ( + _remoteBlocksFetched: Accumulator[Int], + _localBlocksFetched: Accumulator[Int], + _remoteBytesRead: Accumulator[Long], + _localBytesRead: Accumulator[Long], + _fetchWaitTime: Accumulator[Long], + _recordsRead: Accumulator[Long]) + extends Serializable { + + private[executor] def this(accumMap: Map[String, Accumulator[_]]) { + this( + TaskMetrics.getAccum[Int](accumMap, InternalAccumulator.shuffleRead.REMOTE_BLOCKS_FETCHED), + TaskMetrics.getAccum[Int](accumMap, InternalAccumulator.shuffleRead.LOCAL_BLOCKS_FETCHED), + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleRead.REMOTE_BYTES_READ), + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleRead.LOCAL_BYTES_READ), + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleRead.FETCH_WAIT_TIME), + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleRead.RECORDS_READ)) + } + /** - * Number of remote blocks fetched in this shuffle by this task + * Create a new [[ShuffleReadMetrics]] that is not associated with any particular task. + * + * This mainly exists for legacy reasons, because we use dummy [[ShuffleReadMetrics]] in + * many places only to merge their values together later. In the future, we should revisit + * whether this is needed. + * + * A better alternative is [[TaskMetrics.registerTempShuffleReadMetrics]] followed by + * [[TaskMetrics.mergeShuffleReadMetrics]]. */ - private var _remoteBlocksFetched: Int = _ - def remoteBlocksFetched: Int = _remoteBlocksFetched - private[spark] def incRemoteBlocksFetched(value: Int) = _remoteBlocksFetched += value - private[spark] def decRemoteBlocksFetched(value: Int) = _remoteBlocksFetched -= value + private[spark] def this() { + this(InternalAccumulator.createShuffleReadAccums().map { a => (a.name.get, a) }.toMap) + } /** - * Number of local blocks fetched in this shuffle by this task + * Number of remote blocks fetched in this shuffle by this task. */ - private var _localBlocksFetched: Int = _ - def localBlocksFetched: Int = _localBlocksFetched - private[spark] def incLocalBlocksFetched(value: Int) = _localBlocksFetched += value - private[spark] def decLocalBlocksFetched(value: Int) = _localBlocksFetched -= value + def remoteBlocksFetched: Int = _remoteBlocksFetched.localValue /** - * Time the task spent waiting for remote shuffle blocks. This only includes the time - * blocking on shuffle input data. For instance if block B is being fetched while the task is - * still not finished processing block A, it is not considered to be blocking on block B. + * Number of local blocks fetched in this shuffle by this task. */ - private var _fetchWaitTime: Long = _ - def fetchWaitTime: Long = _fetchWaitTime - private[spark] def incFetchWaitTime(value: Long) = _fetchWaitTime += value - private[spark] def decFetchWaitTime(value: Long) = _fetchWaitTime -= value + def localBlocksFetched: Int = _localBlocksFetched.localValue /** - * Total number of remote bytes read from the shuffle by this task + * Total number of remote bytes read from the shuffle by this task. */ - private var _remoteBytesRead: Long = _ - def remoteBytesRead: Long = _remoteBytesRead - private[spark] def incRemoteBytesRead(value: Long) = _remoteBytesRead += value - private[spark] def decRemoteBytesRead(value: Long) = _remoteBytesRead -= value + def remoteBytesRead: Long = _remoteBytesRead.localValue /** * Shuffle data that was read from the local disk (as opposed to from a remote executor). 
*/ - private var _localBytesRead: Long = _ - def localBytesRead: Long = _localBytesRead - private[spark] def incLocalBytesRead(value: Long) = _localBytesRead += value + def localBytesRead: Long = _localBytesRead.localValue /** - * Total bytes fetched in the shuffle by this task (both remote and local). + * Time the task spent waiting for remote shuffle blocks. This only includes the time + * blocking on shuffle input data. For instance if block B is being fetched while the task is + * still not finished processing block A, it is not considered to be blocking on block B. + */ + def fetchWaitTime: Long = _fetchWaitTime.localValue + + /** + * Total number of records read from the shuffle by this task. */ - def totalBytesRead: Long = _remoteBytesRead + _localBytesRead + def recordsRead: Long = _recordsRead.localValue /** - * Number of blocks fetched in this shuffle by this task (remote or local) + * Total bytes fetched in the shuffle by this task (both remote and local). */ - def totalBlocksFetched: Int = _remoteBlocksFetched + _localBlocksFetched + def totalBytesRead: Long = remoteBytesRead + localBytesRead /** - * Total number of records read from the shuffle by this task + * Number of blocks fetched in this shuffle by this task (remote or local). */ - private var _recordsRead: Long = _ - def recordsRead: Long = _recordsRead - private[spark] def incRecordsRead(value: Long) = _recordsRead += value - private[spark] def decRecordsRead(value: Long) = _recordsRead -= value + def totalBlocksFetched: Int = remoteBlocksFetched + localBlocksFetched + + private[spark] def incRemoteBlocksFetched(v: Int): Unit = _remoteBlocksFetched.add(v) + private[spark] def incLocalBlocksFetched(v: Int): Unit = _localBlocksFetched.add(v) + private[spark] def incRemoteBytesRead(v: Long): Unit = _remoteBytesRead.add(v) + private[spark] def incLocalBytesRead(v: Long): Unit = _localBytesRead.add(v) + private[spark] def incFetchWaitTime(v: Long): Unit = _fetchWaitTime.add(v) + private[spark] def incRecordsRead(v: Long): Unit = _recordsRead.add(v) + + private[spark] def setRemoteBlocksFetched(v: Int): Unit = _remoteBlocksFetched.setValue(v) + private[spark] def setLocalBlocksFetched(v: Int): Unit = _localBlocksFetched.setValue(v) + private[spark] def setRemoteBytesRead(v: Long): Unit = _remoteBytesRead.setValue(v) + private[spark] def setLocalBytesRead(v: Long): Unit = _localBytesRead.setValue(v) + private[spark] def setFetchWaitTime(v: Long): Unit = _fetchWaitTime.setValue(v) + private[spark] def setRecordsRead(v: Long): Unit = _recordsRead.setValue(v) + } diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala index 24795f860087..c7aaabb561bb 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala @@ -17,40 +17,66 @@ package org.apache.spark.executor +import org.apache.spark.{Accumulator, InternalAccumulator} import org.apache.spark.annotation.DeveloperApi /** * :: DeveloperApi :: - * Metrics pertaining to shuffle data written in a given task. + * A collection of accumulators that represent metrics about writing shuffle data. + * Operations are not thread-safe. 
*/ @DeveloperApi -class ShuffleWriteMetrics extends Serializable { +class ShuffleWriteMetrics private ( + _bytesWritten: Accumulator[Long], + _recordsWritten: Accumulator[Long], + _writeTime: Accumulator[Long]) + extends Serializable { + + private[executor] def this(accumMap: Map[String, Accumulator[_]]) { + this( + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleWrite.BYTES_WRITTEN), + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleWrite.RECORDS_WRITTEN), + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleWrite.WRITE_TIME)) + } /** - * Number of bytes written for the shuffle by this task + * Create a new [[ShuffleWriteMetrics]] that is not associated with any particular task. + * + * This mainly exists for legacy reasons, because we use dummy [[ShuffleWriteMetrics]] in + * many places only to merge their values together later. In the future, we should revisit + * whether this is needed. + * + * A better alternative is [[TaskMetrics.registerShuffleWriteMetrics]]. */ - @volatile private var _bytesWritten: Long = _ - def bytesWritten: Long = _bytesWritten - private[spark] def incBytesWritten(value: Long) = _bytesWritten += value - private[spark] def decBytesWritten(value: Long) = _bytesWritten -= value + private[spark] def this() { + this(InternalAccumulator.createShuffleWriteAccums().map { a => (a.name.get, a) }.toMap) + } /** - * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds + * Number of bytes written for the shuffle by this task. */ - @volatile private var _writeTime: Long = _ - def writeTime: Long = _writeTime - private[spark] def incWriteTime(value: Long) = _writeTime += value - private[spark] def decWriteTime(value: Long) = _writeTime -= value + def bytesWritten: Long = _bytesWritten.localValue /** - * Total number of records written to the shuffle by this task + * Total number of records written to the shuffle by this task. */ - @volatile private var _recordsWritten: Long = _ - def recordsWritten: Long = _recordsWritten - private[spark] def incRecordsWritten(value: Long) = _recordsWritten += value - private[spark] def decRecordsWritten(value: Long) = _recordsWritten -= value - private[spark] def setRecordsWritten(value: Long) = _recordsWritten = value + def recordsWritten: Long = _recordsWritten.localValue + + /** + * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds. + */ + def writeTime: Long = _writeTime.localValue + + private[spark] def incBytesWritten(v: Long): Unit = _bytesWritten.add(v) + private[spark] def incRecordsWritten(v: Long): Unit = _recordsWritten.add(v) + private[spark] def incWriteTime(v: Long): Unit = _writeTime.add(v) + private[spark] def decBytesWritten(v: Long): Unit = { + _bytesWritten.setValue(bytesWritten - v) + } + private[spark] def decRecordsWritten(v: Long): Unit = { + _recordsWritten.setValue(recordsWritten - v) + } // Legacy methods for backward compatibility. // TODO: remove these once we make this class private. 
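As an illustration of the pattern the new metrics classes above all follow (a minimal sketch, not part of the patch, assuming it compiles inside the org.apache.spark.executor package so that Accumulator and the private[spark] modifier are usable; the class name is a hypothetical stand-in):

    import org.apache.spark.Accumulator

    class WriteTimeMetric(_writeTime: Accumulator[Long]) extends Serializable {
      // Public readers always observe the task-local value of the backing accumulator.
      def writeTime: Long = _writeTime.localValue
      // Mutation is restricted to Spark internals, mirroring ShuffleWriteMetrics above.
      private[spark] def incWriteTime(v: Long): Unit = _writeTime.add(v)
      private[spark] def setWriteTime(v: Long): Unit = _writeTime.setValue(v)
    }

The same shape repeats in InputMetrics, OutputMetrics, ShuffleReadMetrics and ShuffleWriteMetrics: the accumulator is the single source of truth, so the executor no longer copies metric values around and only ships accumulator updates.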
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 32ef5a9b5606..8d10bf588ef1 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -17,90 +17,161 @@ package org.apache.spark.executor -import java.io.{IOException, ObjectInputStream} -import java.util.concurrent.ConcurrentHashMap - +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.executor.DataReadMethod.DataReadMethod +import org.apache.spark.scheduler.AccumulableInfo import org.apache.spark.storage.{BlockId, BlockStatus} -import org.apache.spark.util.Utils /** * :: DeveloperApi :: * Metrics tracked during the execution of a task. * - * This class is used to house metrics both for in-progress and completed tasks. In executors, - * both the task thread and the heartbeat thread write to the TaskMetrics. The heartbeat thread - * reads it to send in-progress metrics, and the task thread reads it to send metrics along with - * the completed task. + * This class is wrapper around a collection of internal accumulators that represent metrics + * associated with a task. The local values of these accumulators are sent from the executor + * to the driver when the task completes. These values are then merged into the corresponding + * accumulator previously registered on the driver. + * + * The accumulator updates are also sent to the driver periodically (on executor heartbeat) + * and when the task failed with an exception. The [[TaskMetrics]] object itself should never + * be sent to the driver. * - * So, when adding new fields, take into consideration that the whole object can be serialized for - * shipping off at any time to consumers of the SparkListener interface. + * @param initialAccums the initial set of accumulators that this [[TaskMetrics]] depends on. + * Each accumulator in this initial set must be uniquely named and marked + * as internal. Additional accumulators registered later need not satisfy + * these requirements. */ @DeveloperApi -class TaskMetrics extends Serializable { +class TaskMetrics(initialAccums: Seq[Accumulator[_]]) extends Serializable { + + import InternalAccumulator._ + + // Needed for Java tests + def this() { + this(InternalAccumulator.create()) + } + + /** + * All accumulators registered with this task. + */ + private val accums = new ArrayBuffer[Accumulable[_, _]] + accums ++= initialAccums + + /** + * A map for quickly accessing the initial set of accumulators by name. 
+ */ + private val initialAccumsMap: Map[String, Accumulator[_]] = { + val map = new mutable.HashMap[String, Accumulator[_]] + initialAccums.foreach { a => + val name = a.name.getOrElse { + throw new IllegalArgumentException( + "initial accumulators passed to TaskMetrics must be named") + } + require(a.isInternal, + s"initial accumulator '$name' passed to TaskMetrics must be marked as internal") + require(!map.contains(name), + s"detected duplicate accumulator name '$name' when constructing TaskMetrics") + map(name) = a + } + map.toMap + } + + // Each metric is internally represented as an accumulator + private val _executorDeserializeTime = getAccum(EXECUTOR_DESERIALIZE_TIME) + private val _executorRunTime = getAccum(EXECUTOR_RUN_TIME) + private val _resultSize = getAccum(RESULT_SIZE) + private val _jvmGCTime = getAccum(JVM_GC_TIME) + private val _resultSerializationTime = getAccum(RESULT_SERIALIZATION_TIME) + private val _memoryBytesSpilled = getAccum(MEMORY_BYTES_SPILLED) + private val _diskBytesSpilled = getAccum(DISK_BYTES_SPILLED) + private val _peakExecutionMemory = getAccum(PEAK_EXECUTION_MEMORY) + private val _updatedBlockStatuses = + TaskMetrics.getAccum[Seq[(BlockId, BlockStatus)]](initialAccumsMap, UPDATED_BLOCK_STATUSES) + /** - * Host's name the task runs on + * Time taken on the executor to deserialize this task. */ - private var _hostname: String = _ - def hostname: String = _hostname - private[spark] def setHostname(value: String) = _hostname = value + def executorDeserializeTime: Long = _executorDeserializeTime.localValue /** - * Time taken on the executor to deserialize this task + * Time the executor spends actually running the task (including fetching shuffle data). */ - private var _executorDeserializeTime: Long = _ - def executorDeserializeTime: Long = _executorDeserializeTime - private[spark] def setExecutorDeserializeTime(value: Long) = _executorDeserializeTime = value + def executorRunTime: Long = _executorRunTime.localValue + /** + * The number of bytes this task transmitted back to the driver as the TaskResult. + */ + def resultSize: Long = _resultSize.localValue /** - * Time the executor spends actually running the task (including fetching shuffle data) + * Amount of time the JVM spent in garbage collection while executing this task. */ - private var _executorRunTime: Long = _ - def executorRunTime: Long = _executorRunTime - private[spark] def setExecutorRunTime(value: Long) = _executorRunTime = value + def jvmGCTime: Long = _jvmGCTime.localValue /** - * The number of bytes this task transmitted back to the driver as the TaskResult + * Amount of time spent serializing the task result. */ - private var _resultSize: Long = _ - def resultSize: Long = _resultSize - private[spark] def setResultSize(value: Long) = _resultSize = value + def resultSerializationTime: Long = _resultSerializationTime.localValue + /** + * The number of in-memory bytes spilled by this task. + */ + def memoryBytesSpilled: Long = _memoryBytesSpilled.localValue /** - * Amount of time the JVM spent in garbage collection while executing this task + * The number of on-disk bytes spilled by this task. */ - private var _jvmGCTime: Long = _ - def jvmGCTime: Long = _jvmGCTime - private[spark] def setJvmGCTime(value: Long) = _jvmGCTime = value + def diskBytesSpilled: Long = _diskBytesSpilled.localValue /** - * Amount of time spent serializing the task result + * Peak memory used by internal data structures created during shuffles, aggregations and + * joins. 
The value of this accumulator should be approximately the sum of the peak sizes + * across all such data structures created in this task. For SQL jobs, this only tracks all + * unsafe operators and ExternalSort. */ - private var _resultSerializationTime: Long = _ - def resultSerializationTime: Long = _resultSerializationTime - private[spark] def setResultSerializationTime(value: Long) = _resultSerializationTime = value + def peakExecutionMemory: Long = _peakExecutionMemory.localValue /** - * The number of in-memory bytes spilled by this task + * Storage statuses of any blocks that have been updated as a result of this task. */ - private var _memoryBytesSpilled: Long = _ - def memoryBytesSpilled: Long = _memoryBytesSpilled - private[spark] def incMemoryBytesSpilled(value: Long): Unit = _memoryBytesSpilled += value - private[spark] def decMemoryBytesSpilled(value: Long): Unit = _memoryBytesSpilled -= value + def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = _updatedBlockStatuses.localValue + + @deprecated("use updatedBlockStatuses instead", "2.0.0") + def updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = { + if (updatedBlockStatuses.nonEmpty) Some(updatedBlockStatuses) else None + } + + // Setters and increment-ers + private[spark] def setExecutorDeserializeTime(v: Long): Unit = + _executorDeserializeTime.setValue(v) + private[spark] def setExecutorRunTime(v: Long): Unit = _executorRunTime.setValue(v) + private[spark] def setResultSize(v: Long): Unit = _resultSize.setValue(v) + private[spark] def setJvmGCTime(v: Long): Unit = _jvmGCTime.setValue(v) + private[spark] def setResultSerializationTime(v: Long): Unit = + _resultSerializationTime.setValue(v) + private[spark] def incMemoryBytesSpilled(v: Long): Unit = _memoryBytesSpilled.add(v) + private[spark] def incDiskBytesSpilled(v: Long): Unit = _diskBytesSpilled.add(v) + private[spark] def incPeakExecutionMemory(v: Long): Unit = _peakExecutionMemory.add(v) + private[spark] def incUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit = + _updatedBlockStatuses.add(v) + private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit = + _updatedBlockStatuses.setValue(v) /** - * The number of on-disk bytes spilled by this task + * Get a Long accumulator from the given map by name, assuming it exists. + * Note: this only searches the initial set of accumulators passed into the constructor. 
*/ - private var _diskBytesSpilled: Long = _ - def diskBytesSpilled: Long = _diskBytesSpilled - private[spark] def incDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled += value - private[spark] def decDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled -= value + private[spark] def getAccum(name: String): Accumulator[Long] = { + TaskMetrics.getAccum[Long](initialAccumsMap, name) + } + + + /* ========================== * + | INPUT METRICS | + * ========================== */ private var _inputMetrics: Option[InputMetrics] = None @@ -116,7 +187,8 @@ class TaskMetrics extends Serializable { private[spark] def registerInputMetrics(readMethod: DataReadMethod.Value): InputMetrics = { synchronized { val metrics = _inputMetrics.getOrElse { - val metrics = new InputMetrics(readMethod) + val metrics = new InputMetrics(initialAccumsMap) + metrics.setReadMethod(readMethod) _inputMetrics = Some(metrics) metrics } @@ -128,18 +200,17 @@ class TaskMetrics extends Serializable { if (metrics.readMethod == readMethod) { metrics } else { - new InputMetrics(readMethod) + val m = new InputMetrics + m.setReadMethod(readMethod) + m } } } - /** - * This should only be used when recreating TaskMetrics, not when updating input metrics in - * executors - */ - private[spark] def setInputMetrics(inputMetrics: Option[InputMetrics]) { - _inputMetrics = inputMetrics - } + + /* ============================ * + | OUTPUT METRICS | + * ============================ */ private var _outputMetrics: Option[OutputMetrics] = None @@ -149,23 +220,24 @@ class TaskMetrics extends Serializable { */ def outputMetrics: Option[OutputMetrics] = _outputMetrics - @deprecated("setting OutputMetrics is for internal use only", "2.0.0") - def outputMetrics_=(om: Option[OutputMetrics]): Unit = { - _outputMetrics = om - } - /** * Get or create a new [[OutputMetrics]] associated with this task. */ private[spark] def registerOutputMetrics( writeMethod: DataWriteMethod.Value): OutputMetrics = synchronized { _outputMetrics.getOrElse { - val metrics = new OutputMetrics(writeMethod) + val metrics = new OutputMetrics(initialAccumsMap) + metrics.setWriteMethod(writeMethod) _outputMetrics = Some(metrics) metrics } } + + /* ================================== * + | SHUFFLE READ METRICS | + * ================================== */ + private var _shuffleReadMetrics: Option[ShuffleReadMetrics] = None /** @@ -174,21 +246,13 @@ class TaskMetrics extends Serializable { */ def shuffleReadMetrics: Option[ShuffleReadMetrics] = _shuffleReadMetrics - /** - * This should only be used when recreating TaskMetrics, not when updating read metrics in - * executors. - */ - private[spark] def setShuffleReadMetrics(shuffleReadMetrics: Option[ShuffleReadMetrics]) { - _shuffleReadMetrics = shuffleReadMetrics - } - /** * Temporary list of [[ShuffleReadMetrics]], one per shuffle dependency. * * A task may have multiple shuffle readers for multiple dependencies. To avoid synchronization * issues from readers in different threads, in-progress tasks use a [[ShuffleReadMetrics]] for * each dependency and merge these metrics before reporting them to the driver. 
- */ + */ @transient private lazy val tempShuffleReadMetrics = new ArrayBuffer[ShuffleReadMetrics] /** @@ -210,19 +274,21 @@ class TaskMetrics extends Serializable { */ private[spark] def mergeShuffleReadMetrics(): Unit = synchronized { if (tempShuffleReadMetrics.nonEmpty) { - val merged = new ShuffleReadMetrics - for (depMetrics <- tempShuffleReadMetrics) { - merged.incFetchWaitTime(depMetrics.fetchWaitTime) - merged.incLocalBlocksFetched(depMetrics.localBlocksFetched) - merged.incRemoteBlocksFetched(depMetrics.remoteBlocksFetched) - merged.incRemoteBytesRead(depMetrics.remoteBytesRead) - merged.incLocalBytesRead(depMetrics.localBytesRead) - merged.incRecordsRead(depMetrics.recordsRead) - } - _shuffleReadMetrics = Some(merged) + val metrics = new ShuffleReadMetrics(initialAccumsMap) + metrics.setRemoteBlocksFetched(tempShuffleReadMetrics.map(_.remoteBlocksFetched).sum) + metrics.setLocalBlocksFetched(tempShuffleReadMetrics.map(_.localBlocksFetched).sum) + metrics.setFetchWaitTime(tempShuffleReadMetrics.map(_.fetchWaitTime).sum) + metrics.setRemoteBytesRead(tempShuffleReadMetrics.map(_.remoteBytesRead).sum) + metrics.setLocalBytesRead(tempShuffleReadMetrics.map(_.localBytesRead).sum) + metrics.setRecordsRead(tempShuffleReadMetrics.map(_.recordsRead).sum) + _shuffleReadMetrics = Some(metrics) } } + /* =================================== * + | SHUFFLE WRITE METRICS | + * =================================== */ + private var _shuffleWriteMetrics: Option[ShuffleWriteMetrics] = None /** @@ -230,86 +296,120 @@ class TaskMetrics extends Serializable { */ def shuffleWriteMetrics: Option[ShuffleWriteMetrics] = _shuffleWriteMetrics - @deprecated("setting ShuffleWriteMetrics is for internal use only", "2.0.0") - def shuffleWriteMetrics_=(swm: Option[ShuffleWriteMetrics]): Unit = { - _shuffleWriteMetrics = swm - } - /** * Get or create a new [[ShuffleWriteMetrics]] associated with this task. */ private[spark] def registerShuffleWriteMetrics(): ShuffleWriteMetrics = synchronized { _shuffleWriteMetrics.getOrElse { - val metrics = new ShuffleWriteMetrics + val metrics = new ShuffleWriteMetrics(initialAccumsMap) _shuffleWriteMetrics = Some(metrics) metrics } } - private var _updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = - Seq.empty[(BlockId, BlockStatus)] - /** - * Storage statuses of any blocks that have been updated as a result of this task. - */ - def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = _updatedBlockStatuses + /* ========================== * + | OTHER THINGS | + * ========================== */ - @deprecated("setting updated blocks is for internal use only", "2.0.0") - def updatedBlocks_=(ub: Option[Seq[(BlockId, BlockStatus)]]): Unit = { - _updatedBlockStatuses = ub.getOrElse(Seq.empty[(BlockId, BlockStatus)]) + private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit = { + accums += a } - private[spark] def incUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit = { - _updatedBlockStatuses ++= v + /** + * Return the latest updates of accumulators in this task. + * + * The [[AccumulableInfo.update]] field is always defined and the [[AccumulableInfo.value]] + * field is always empty, since this represents the partial updates recorded in this task, + * not the aggregated value across multiple tasks. 
+ */ + def accumulatorUpdates(): Seq[AccumulableInfo] = accums.map { a => + new AccumulableInfo(a.id, a.name, Some(a.localValue), None, a.isInternal, a.countFailedValues) } - private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit = { - _updatedBlockStatuses = v + // If we are reconstructing this TaskMetrics on the driver, some metrics may already be set. + // If so, initialize all relevant metrics classes so listeners can access them downstream. + { + var (hasShuffleRead, hasShuffleWrite, hasInput, hasOutput) = (false, false, false, false) + initialAccums + .filter { a => a.localValue != a.zero } + .foreach { a => + a.name.get match { + case sr if sr.startsWith(SHUFFLE_READ_METRICS_PREFIX) => hasShuffleRead = true + case sw if sw.startsWith(SHUFFLE_WRITE_METRICS_PREFIX) => hasShuffleWrite = true + case in if in.startsWith(INPUT_METRICS_PREFIX) => hasInput = true + case out if out.startsWith(OUTPUT_METRICS_PREFIX) => hasOutput = true + case _ => + } + } + if (hasShuffleRead) { _shuffleReadMetrics = Some(new ShuffleReadMetrics(initialAccumsMap)) } + if (hasShuffleWrite) { _shuffleWriteMetrics = Some(new ShuffleWriteMetrics(initialAccumsMap)) } + if (hasInput) { _inputMetrics = Some(new InputMetrics(initialAccumsMap)) } + if (hasOutput) { _outputMetrics = Some(new OutputMetrics(initialAccumsMap)) } } - @deprecated("use updatedBlockStatuses instead", "2.0.0") - def updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = { - if (_updatedBlockStatuses.nonEmpty) Some(_updatedBlockStatuses) else None - } +} - private[spark] def updateInputMetrics(): Unit = synchronized { - inputMetrics.foreach(_.updateBytesRead()) - } +private[spark] object TaskMetrics extends Logging { - @throws(classOf[IOException]) - private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { - in.defaultReadObject() - // Get the hostname from cached data, since hostname is the order of number of nodes in - // cluster, so using cached hostname will decrease the object number and alleviate the GC - // overhead. - _hostname = TaskMetrics.getCachedHostName(_hostname) - } - - private var _accumulatorUpdates: Map[Long, Any] = Map.empty - @transient private var _accumulatorsUpdater: () => Map[Long, Any] = null + def empty: TaskMetrics = new TaskMetrics - private[spark] def updateAccumulators(): Unit = synchronized { - _accumulatorUpdates = _accumulatorsUpdater() + /** + * Get an accumulator from the given map by name, assuming it exists. + */ + def getAccum[T](accumMap: Map[String, Accumulator[_]], name: String): Accumulator[T] = { + require(accumMap.contains(name), s"metric '$name' is missing") + val accum = accumMap(name) + try { + // Note: we can't do pattern matching here because types are erased by compile time + accum.asInstanceOf[Accumulator[T]] + } catch { + case e: ClassCastException => + throw new SparkException(s"accumulator $name was of unexpected type", e) + } } /** - * Return the latest updates of accumulators in this task. + * Construct a [[TaskMetrics]] object from a list of accumulator updates, called on driver only. + * + * Executors only send accumulator updates back to the driver, not [[TaskMetrics]]. However, we + * need the latter to post task end events to listeners, so we need to reconstruct the metrics + * on the driver. + * + * This assumes the provided updates contain the initial set of accumulators representing + * internal task level metrics. 
*/ - def accumulatorUpdates(): Map[Long, Any] = _accumulatorUpdates - - private[spark] def setAccumulatorsUpdater(accumulatorsUpdater: () => Map[Long, Any]): Unit = { - _accumulatorsUpdater = accumulatorsUpdater + def fromAccumulatorUpdates(accumUpdates: Seq[AccumulableInfo]): TaskMetrics = { + // Initial accumulators are passed into the TaskMetrics constructor first because these + // are required to be uniquely named. The rest of the accumulators from this task are + // registered later because they need not satisfy this requirement. + val (initialAccumInfos, otherAccumInfos) = accumUpdates + .filter { info => info.update.isDefined } + .partition { info => info.name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX)) } + val initialAccums = initialAccumInfos.map { info => + val accum = InternalAccumulator.create(info.name.get) + accum.setValueAny(info.update.get) + accum + } + // We don't know the types of the rest of the accumulators, so we try to find the same ones + // that were previously registered here on the driver and make copies of them. It is important + // that we copy the accumulators here since they are used across many tasks and we want to + // maintain a snapshot of their local task values when we post them to listeners downstream. + val otherAccums = otherAccumInfos.flatMap { info => + val id = info.id + val acc = Accumulators.get(id).map { a => + val newAcc = a.copy() + newAcc.setValueAny(info.update.get) + newAcc + } + if (acc.isEmpty) { + logWarning(s"encountered unregistered accumulator $id when reconstructing task metrics.") + } + acc + } + val metrics = new TaskMetrics(initialAccums) + otherAccums.foreach(metrics.registerAccumulator) + metrics } -} - -private[spark] object TaskMetrics { - private val hostNameCache = new ConcurrentHashMap[String, String]() - - def empty: TaskMetrics = new TaskMetrics - - def getCachedHostName(host: String): String = { - val canonicalHost = hostNameCache.putIfAbsent(host, host) - if (canonicalHost != null) canonicalHost else host - } } diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 3587e7eb1afa..d9b0824b38ec 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -153,8 +153,7 @@ class CoGroupedRDD[K: ClassTag]( } context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) - context.internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) + context.taskMetrics().incPeakExecutionMemory(map.peakMemoryUsedBytes) new InterruptibleIterator(context, map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index a79ab86d4922..3204e6adceca 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -212,6 +212,8 @@ class HadoopRDD[K, V]( logInfo("Input split: " + split.inputSplit) val jobConf = getJobConf() + // TODO: there is a lot of duplicate code between this and NewHadoopRDD and SqlNewHadoopRDD + val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop) // Sets the thread local variable for the file's name @@ -222,14 +224,17 @@ class HadoopRDD[K, V]( // Find a function that will return the FileSystem 
bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { - split.inputSplit.value match { - case _: FileSplit | _: CombineFileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() - case _ => None + val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match { + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + case _ => None + } + + def updateBytesRead(): Unit = { + getBytesReadCallback.foreach { getBytesRead => + inputMetrics.setBytesRead(getBytesRead()) } } - inputMetrics.setBytesReadCallback(bytesReadCallback) var reader: RecordReader[K, V] = null val inputFormat = getInputFormat(jobConf) @@ -252,6 +257,9 @@ class HadoopRDD[K, V]( if (!finished) { inputMetrics.incRecordsRead(1) } + if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { + updateBytesRead() + } (key, value) } @@ -272,8 +280,8 @@ class HadoopRDD[K, V]( } finally { reader = null } - if (bytesReadCallback.isDefined) { - inputMetrics.updateBytesRead() + if (getBytesReadCallback.isDefined) { + updateBytesRead() } else if (split.inputSplit.value.isInstanceOf[FileSplit] || split.inputSplit.value.isInstanceOf[CombineFileSplit]) { // If we can't get the bytes read from the FS stats, fall back to the split size, diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 5cc9c81cc674..4d2816e335fe 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -133,14 +133,17 @@ class NewHadoopRDD[K, V]( // Find a function that will return the FileSystem bytes read by this thread. 
Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { - split.serializableHadoopSplit.value match { - case _: FileSplit | _: CombineFileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() - case _ => None + val getBytesReadCallback: Option[() => Long] = split.serializableHadoopSplit.value match { + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + case _ => None + } + + def updateBytesRead(): Unit = { + getBytesReadCallback.foreach { getBytesRead => + inputMetrics.setBytesRead(getBytesRead()) } } - inputMetrics.setBytesReadCallback(bytesReadCallback) val format = inputFormatClass.newInstance format match { @@ -182,6 +185,9 @@ class NewHadoopRDD[K, V]( if (!finished) { inputMetrics.incRecordsRead(1) } + if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { + updateBytesRead() + } (reader.getCurrentKey, reader.getCurrentValue) } @@ -201,8 +207,8 @@ class NewHadoopRDD[K, V]( } finally { reader = null } - if (bytesReadCallback.isDefined) { - inputMetrics.updateBytesRead() + if (getBytesReadCallback.isDefined) { + updateBytesRead() } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { // If we can't get the bytes read from the FS stats, fall back to the split size, diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala index 146cfb9ba803..9d45fff9213c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -19,47 +19,58 @@ package org.apache.spark.scheduler import org.apache.spark.annotation.DeveloperApi + /** * :: DeveloperApi :: * Information about an [[org.apache.spark.Accumulable]] modified during a task or stage. + * + * Note: once this is JSON serialized the types of `update` and `value` will be lost and be + * cast to strings. This is because the user can define an accumulator of any type and it will + * be difficult to preserve the type in consumers of the event log. This does not apply to + * internal accumulators that represent task level metrics. 
+ * + * @param id accumulator ID + * @param name accumulator name + * @param update partial value from a task, may be None if used on driver to describe a stage + * @param value total accumulated value so far, maybe None if used on executors to describe a task + * @param internal whether this accumulator was internal + * @param countFailedValues whether to count this accumulator's partial value if the task failed */ @DeveloperApi -class AccumulableInfo private[spark] ( - val id: Long, - val name: String, - val update: Option[String], // represents a partial update within a task - val value: String, - val internal: Boolean) { - - override def equals(other: Any): Boolean = other match { - case acc: AccumulableInfo => - this.id == acc.id && this.name == acc.name && - this.update == acc.update && this.value == acc.value && - this.internal == acc.internal - case _ => false - } +case class AccumulableInfo private[spark] ( + id: Long, + name: Option[String], + update: Option[Any], // represents a partial update within a task + value: Option[Any], + private[spark] val internal: Boolean, + private[spark] val countFailedValues: Boolean) - override def hashCode(): Int = { - val state = Seq(id, name, update, value, internal) - state.map(_.hashCode).reduceLeft(31 * _ + _) - } -} +/** + * A collection of deprecated constructors. This will be removed soon. + */ object AccumulableInfo { + + @deprecated("do not create AccumulableInfo", "2.0.0") def apply( id: Long, name: String, update: Option[String], value: String, internal: Boolean): AccumulableInfo = { - new AccumulableInfo(id, name, update, value, internal) + new AccumulableInfo( + id, Option(name), update, Option(value), internal, countFailedValues = false) } + @deprecated("do not create AccumulableInfo", "2.0.0") def apply(id: Long, name: String, update: Option[String], value: String): AccumulableInfo = { - new AccumulableInfo(id, name, update, value, internal = false) + new AccumulableInfo( + id, Option(name), update, Option(value), internal = false, countFailedValues = false) } + @deprecated("do not create AccumulableInfo", "2.0.0") def apply(id: Long, name: String, value: String): AccumulableInfo = { - new AccumulableInfo(id, name, None, value, internal = false) + new AccumulableInfo( + id, Option(name), None, Option(value), internal = false, countFailedValues = false) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 6b01a10fc136..897479b50010 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -208,11 +208,10 @@ class DAGScheduler( task: Task[_], reason: TaskEndReason, result: Any, - accumUpdates: Map[Long, Any], - taskInfo: TaskInfo, - taskMetrics: TaskMetrics): Unit = { + accumUpdates: Seq[AccumulableInfo], + taskInfo: TaskInfo): Unit = { eventProcessLoop.post( - CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)) + CompletionEvent(task, reason, result, accumUpdates, taskInfo)) } /** @@ -222,9 +221,10 @@ class DAGScheduler( */ def executorHeartbeatReceived( execId: String, - taskMetrics: Array[(Long, Int, Int, TaskMetrics)], // (taskId, stageId, stateAttempt, metrics) + // (taskId, stageId, stageAttemptId, accumUpdates) + accumUpdates: Array[(Long, Int, Int, Seq[AccumulableInfo])], blockManagerId: BlockManagerId): Boolean = { - listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics)) + 
listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, accumUpdates)) blockManagerMaster.driverEndpoint.askWithRetry[Boolean]( BlockManagerHeartbeat(blockManagerId), new RpcTimeout(600 seconds, "BlockManagerHeartbeat")) } @@ -1074,39 +1074,43 @@ class DAGScheduler( } } - /** Merge updates from a task to our local accumulator values */ + /** + * Merge local values from a task into the corresponding accumulators previously registered + * here on the driver. + * + * Although accumulators themselves are not thread-safe, this method is called only from one + * thread, the one that runs the scheduling loop. This means we only handle one task + * completion event at a time so we don't need to worry about locking the accumulators. + * This still doesn't stop the caller from updating the accumulator outside the scheduler, + * but that's not our problem since there's nothing we can do about that. + */ private def updateAccumulators(event: CompletionEvent): Unit = { val task = event.task val stage = stageIdToStage(task.stageId) - if (event.accumUpdates != null) { - try { - Accumulators.add(event.accumUpdates) - - event.accumUpdates.foreach { case (id, partialValue) => - // In this instance, although the reference in Accumulators.originals is a WeakRef, - // it's guaranteed to exist since the event.accumUpdates Map exists - - val acc = Accumulators.originals(id).get match { - case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]] - case None => throw new NullPointerException("Non-existent reference to Accumulator") - } - - // To avoid UI cruft, ignore cases where value wasn't updated - if (acc.name.isDefined && partialValue != acc.zero) { - val name = acc.name.get - val value = s"${acc.value}" - stage.latestInfo.accumulables(id) = - new AccumulableInfo(id, name, None, value, acc.isInternal) - event.taskInfo.accumulables += - new AccumulableInfo(id, name, Some(s"$partialValue"), value, acc.isInternal) - } + try { + event.accumUpdates.foreach { ainfo => + assert(ainfo.update.isDefined, "accumulator from task should have a partial value") + val id = ainfo.id + val partialValue = ainfo.update.get + // Find the corresponding accumulator on the driver and update it + val acc: Accumulable[Any, Any] = Accumulators.get(id) match { + case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]] + case None => + throw new SparkException(s"attempted to access non-existent accumulator $id") + } + acc ++= partialValue + // To avoid UI cruft, ignore cases where value wasn't updated + if (acc.name.isDefined && partialValue != acc.zero) { + val name = acc.name + stage.latestInfo.accumulables(id) = new AccumulableInfo( + id, name, None, Some(acc.value), acc.isInternal, acc.countFailedValues) + event.taskInfo.accumulables += new AccumulableInfo( + id, name, Some(partialValue), Some(acc.value), acc.isInternal, acc.countFailedValues) } - } catch { - // If we see an exception during accumulator update, just log the - // error and move on. 
- case e: Exception => - logError(s"Failed to update accumulators for $task", e) } + } catch { + case NonFatal(e) => + logError(s"Failed to update accumulators for task ${task.partitionId}", e) } } @@ -1116,6 +1120,7 @@ class DAGScheduler( */ private[scheduler] def handleTaskCompletion(event: CompletionEvent) { val task = event.task + val taskId = event.taskInfo.id val stageId = task.stageId val taskType = Utils.getFormattedClassName(task) @@ -1125,12 +1130,26 @@ class DAGScheduler( event.taskInfo.attemptNumber, // this is a task attempt number event.reason) - // The success case is dealt with separately below, since we need to compute accumulator - // updates before posting. + // Reconstruct task metrics. Note: this may be null if the task has failed. + val taskMetrics: TaskMetrics = + if (event.accumUpdates.nonEmpty) { + try { + TaskMetrics.fromAccumulatorUpdates(event.accumUpdates) + } catch { + case NonFatal(e) => + logError(s"Error when attempting to reconstruct metrics for task $taskId", e) + null + } + } else { + null + } + + // The success case is dealt with separately below. + // TODO: Why post it only for failed tasks in cancelled stages? Clarify semantics here. if (event.reason != Success) { val attemptId = task.stageAttemptId - listenerBus.post(SparkListenerTaskEnd(stageId, attemptId, taskType, event.reason, - event.taskInfo, event.taskMetrics)) + listenerBus.post(SparkListenerTaskEnd( + stageId, attemptId, taskType, event.reason, event.taskInfo, taskMetrics)) } if (!stageIdToStage.contains(task.stageId)) { @@ -1142,7 +1161,7 @@ class DAGScheduler( event.reason match { case Success => listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType, - event.reason, event.taskInfo, event.taskMetrics)) + event.reason, event.taskInfo, taskMetrics)) stage.pendingPartitions -= task.partitionId task match { case rt: ResultTask[_, _] => @@ -1291,7 +1310,8 @@ class DAGScheduler( // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits case exceptionFailure: ExceptionFailure => - // Do nothing here, left up to the TaskScheduler to decide how to handle user failures + // Tasks failed with exceptions might still have accumulator updates. + updateAccumulators(event) case TaskResultLost => // Do nothing here; the TaskScheduler handles these failures and resubmits the task. 
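As an illustration of the driver-side merge performed by updateAccumulators above (a hedged sketch, not part of the patch; `sc` is an existing SparkContext, and the accumulator name and values are made up):

    // An accumulator registered on the driver, e.g. by user code or by InternalAccumulator.
    val failedRecords = sc.accumulator(0L, "failedRecords")
    // A partial value reported back by one task inside an AccumulableInfo.
    val partialValue: Long = 42L
    // The scheduler folds the partial value in exactly like `acc ++= partialValue` above.
    failedRecords ++= partialValue
    assert(failedRecords.value == 42L)  // the driver-side total now includes the task's contribution

Because the scheduling loop handles one CompletionEvent at a time, these merges never race with each other.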
@@ -1637,7 +1657,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case GettingResultEvent(taskInfo) => dagScheduler.handleGetTaskResult(taskInfo) - case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) => + case completion: CompletionEvent => dagScheduler.handleTaskCompletion(completion) case TaskSetFailed(taskSet, reason, exception) => diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index dda3b6cc7f96..d5cd2da7a10d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -73,9 +73,8 @@ private[scheduler] case class CompletionEvent( task: Task[_], reason: TaskEndReason, result: Any, - accumUpdates: Map[Long, Any], - taskInfo: TaskInfo, - taskMetrics: TaskMetrics) + accumUpdates: Seq[AccumulableInfo], + taskInfo: TaskInfo) extends DAGSchedulerEvent private[scheduler] case class ExecutorAdded(execId: String, host: String) extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 6590cf6ffd24..885f70e89fbf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -30,6 +30,7 @@ import org.apache.spark.rdd.RDD * See [[Task]] for more information. * * @param stageId id of the stage this task belongs to + * @param stageAttemptId attempt id of the stage this task belongs to * @param taskBinary broadcasted version of the serialized RDD and the function to apply on each * partition of the given RDD. Once deserialized, the type should be * (RDD[T], (TaskContext, Iterator[T]) => U). @@ -37,6 +38,9 @@ import org.apache.spark.rdd.RDD * @param locs preferred task execution locations for locality scheduling * @param outputId index of the task in this job (a job can launch tasks on only a subset of the * input RDD's partitions). + * @param _initialAccums initial set of accumulators to be used in this task for tracking + * internal metrics. Other accumulators will be registered later when + * they are deserialized on the executors. */ private[spark] class ResultTask[T, U]( stageId: Int, @@ -45,8 +49,8 @@ private[spark] class ResultTask[T, U]( partition: Partition, locs: Seq[TaskLocation], val outputId: Int, - internalAccumulators: Seq[Accumulator[Long]]) - extends Task[U](stageId, stageAttemptId, partition.index, internalAccumulators) + _initialAccums: Seq[Accumulator[_]] = InternalAccumulator.create()) + extends Task[U](stageId, stageAttemptId, partition.index, _initialAccums) with Serializable { @transient private[this] val preferredLocs: Seq[TaskLocation] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index ea97ef0e746d..89207dd175ae 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -33,10 +33,14 @@ import org.apache.spark.shuffle.ShuffleWriter * See [[org.apache.spark.scheduler.Task]] for more information. * * @param stageId id of the stage this task belongs to + * @param stageAttemptId attempt id of the stage this task belongs to * @param taskBinary broadcast version of the RDD and the ShuffleDependency. 
Once deserialized, * the type should be (RDD[_], ShuffleDependency[_, _, _]). * @param partition partition of the RDD this task is associated with * @param locs preferred task execution locations for locality scheduling + * @param _initialAccums initial set of accumulators to be used in this task for tracking + * internal metrics. Other accumulators will be registered later when + * they are deserialized on the executors. */ private[spark] class ShuffleMapTask( stageId: Int, @@ -44,8 +48,8 @@ private[spark] class ShuffleMapTask( taskBinary: Broadcast[Array[Byte]], partition: Partition, @transient private var locs: Seq[TaskLocation], - internalAccumulators: Seq[Accumulator[Long]]) - extends Task[MapStatus](stageId, stageAttemptId, partition.index, internalAccumulators) + _initialAccums: Seq[Accumulator[_]]) + extends Task[MapStatus](stageId, stageAttemptId, partition.index, _initialAccums) with Logging { /** A constructor used only in test suites. This does not require passing in an RDD. */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 6c6883d703be..ed3adbd81c28 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import java.util.Properties +import javax.annotation.Nullable import scala.collection.Map import scala.collection.mutable @@ -60,7 +61,7 @@ case class SparkListenerTaskEnd( taskType: String, reason: TaskEndReason, taskInfo: TaskInfo, - taskMetrics: TaskMetrics) + @Nullable taskMetrics: TaskMetrics) extends SparkListenerEvent @DeveloperApi @@ -111,12 +112,12 @@ case class SparkListenerBlockUpdated(blockUpdatedInfo: BlockUpdatedInfo) extends /** * Periodic updates from executors. * @param execId executor id - * @param taskMetrics sequence of (task id, stage id, stage attempt, metrics) + * @param accumUpdates sequence of (taskId, stageId, stageAttemptId, accumUpdates) */ @DeveloperApi case class SparkListenerExecutorMetricsUpdate( execId: String, - taskMetrics: Seq[(Long, Int, Int, TaskMetrics)]) + accumUpdates: Seq[(Long, Int, Int, Seq[AccumulableInfo])]) extends SparkListenerEvent @DeveloperApi diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 7ea24a217bd3..c1c8b47128f2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -74,10 +74,10 @@ private[scheduler] abstract class Stage( val name: String = callSite.shortForm val details: String = callSite.longForm - private var _internalAccumulators: Seq[Accumulator[Long]] = Seq.empty + private var _internalAccumulators: Seq[Accumulator[_]] = Seq.empty /** Internal accumulators shared across all tasks in this stage. */ - def internalAccumulators: Seq[Accumulator[Long]] = _internalAccumulators + def internalAccumulators: Seq[Accumulator[_]] = _internalAccumulators /** * Re-initialize the internal accumulators associated with this stage. 
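As an illustration of how the untyped internal accumulators above flow into TaskMetrics (a hedged sketch, not part of the patch, assuming code inside the org.apache.spark package so the private[spark] members are accessible):

    import org.apache.spark.{Accumulator, InternalAccumulator}
    import org.apache.spark.executor.TaskMetrics
    import org.apache.spark.scheduler.AccumulableInfo

    // One uniquely named, internal accumulator per task-level metric.
    val initialAccums: Seq[Accumulator[_]] = InternalAccumulator.create()
    val metrics = new TaskMetrics(initialAccums)
    metrics.setExecutorRunTime(1234L)  // backed by the EXECUTOR_RUN_TIME accumulator
    // What the executor ships to the driver: partial updates only, never the TaskMetrics object.
    val updates: Seq[AccumulableInfo] = metrics.accumulatorUpdates()
    assert(updates.flatMap(_.name).contains(InternalAccumulator.EXECUTOR_RUN_TIME))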
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index fca57928eca1..a49f3716e270 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler -import java.io.{ByteArrayOutputStream, DataInputStream, DataOutputStream} +import java.io.{DataInputStream, DataOutputStream} import java.nio.ByteBuffer import scala.collection.mutable.HashMap @@ -41,32 +41,29 @@ import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Uti * and divides the task output to multiple buckets (based on the task's partitioner). * * @param stageId id of the stage this task belongs to + * @param stageAttemptId attempt id of the stage this task belongs to * @param partitionId index of the number in the RDD + * @param initialAccumulators initial set of accumulators to be used in this task for tracking + * internal metrics. Other accumulators will be registered later when + * they are deserialized on the executors. */ private[spark] abstract class Task[T]( val stageId: Int, val stageAttemptId: Int, val partitionId: Int, - internalAccumulators: Seq[Accumulator[Long]]) extends Serializable { + val initialAccumulators: Seq[Accumulator[_]]) extends Serializable { /** - * The key of the Map is the accumulator id and the value of the Map is the latest accumulator - * local value. - */ - type AccumulatorUpdates = Map[Long, Any] - - /** - * Called by [[Executor]] to run this task. + * Called by [[org.apache.spark.executor.Executor]] to run this task. * * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext. * @param attemptNumber how many times this task has been attempted (0 for the first attempt) * @return the result of the task along with updates of Accumulators. */ final def run( - taskAttemptId: Long, - attemptNumber: Int, - metricsSystem: MetricsSystem) - : (T, AccumulatorUpdates) = { + taskAttemptId: Long, + attemptNumber: Int, + metricsSystem: MetricsSystem): T = { context = new TaskContextImpl( stageId, partitionId, @@ -74,16 +71,14 @@ private[spark] abstract class Task[T]( attemptNumber, taskMemoryManager, metricsSystem, - internalAccumulators) + initialAccumulators) TaskContext.setTaskContext(context) - context.taskMetrics.setHostname(Utils.localHostName()) - context.taskMetrics.setAccumulatorsUpdater(context.collectInternalAccumulators) taskThread = Thread.currentThread() if (_killed) { kill(interruptThread = false) } try { - (runTask(context), context.collectAccumulators()) + runTask(context) } finally { context.markTaskCompleted() try { @@ -140,6 +135,18 @@ private[spark] abstract class Task[T]( */ def executorDeserializeTime: Long = _executorDeserializeTime + /** + * Collect the latest values of accumulators used in this task. If the task failed, + * filter out the accumulators whose values should not be included on failures. + */ + def collectAccumulatorUpdates(taskFailed: Boolean = false): Seq[AccumulableInfo] = { + if (context != null) { + context.taskMetrics.accumulatorUpdates().filter { a => !taskFailed || a.countFailedValues } + } else { + Seq.empty[AccumulableInfo] + } + } + /** * Kills a task by setting the interrupted flag to true. This relies on the upper level Spark * code and user code to properly handle the flag. 
This function should be idempotent so it can diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index b82c7f3fa54f..03135e63d755 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -20,11 +20,9 @@ package org.apache.spark.scheduler import java.io._ import java.nio.ByteBuffer -import scala.collection.Map -import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkEnv -import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockId import org.apache.spark.util.Utils @@ -36,31 +34,24 @@ private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Int) extends TaskResult[T] with Serializable /** A TaskResult that contains the task's return value and accumulator updates. */ -private[spark] -class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long, Any], - var metrics: TaskMetrics) +private[spark] class DirectTaskResult[T]( + var valueBytes: ByteBuffer, + var accumUpdates: Seq[AccumulableInfo]) extends TaskResult[T] with Externalizable { private var valueObjectDeserialized = false private var valueObject: T = _ - def this() = this(null.asInstanceOf[ByteBuffer], null, null) + def this() = this(null.asInstanceOf[ByteBuffer], null) override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { - - out.writeInt(valueBytes.remaining); + out.writeInt(valueBytes.remaining) Utils.writeByteBuffer(valueBytes, out) - out.writeInt(accumUpdates.size) - for ((key, value) <- accumUpdates) { - out.writeLong(key) - out.writeObject(value) - } - out.writeObject(metrics) + accumUpdates.foreach(out.writeObject) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { - val blen = in.readInt() val byteVal = new Array[Byte](blen) in.readFully(byteVal) @@ -70,13 +61,12 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long if (numUpdates == 0) { accumUpdates = null } else { - val _accumUpdates = mutable.Map[Long, Any]() + val _accumUpdates = new ArrayBuffer[AccumulableInfo] for (i <- 0 until numUpdates) { - _accumUpdates(in.readLong()) = in.readObject() + _accumUpdates += in.readObject.asInstanceOf[AccumulableInfo] } accumUpdates = _accumUpdates } - metrics = in.readObject().asInstanceOf[TaskMetrics] valueObjectDeserialized = false } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index f4965994d827..c94c4f55e9ce 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer -import java.util.concurrent.RejectedExecutionException +import java.util.concurrent.{ExecutorService, RejectedExecutionException} import scala.language.existentials import scala.util.control.NonFatal @@ -35,9 +35,12 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul extends Logging { private val THREADS = sparkEnv.conf.getInt("spark.resultGetter.threads", 4) - private val getTaskResultExecutor = ThreadUtils.newDaemonFixedThreadPool( - THREADS, "task-result-getter") + // Exposed for testing. 
+ protected val getTaskResultExecutor: ExecutorService = + ThreadUtils.newDaemonFixedThreadPool(THREADS, "task-result-getter") + + // Exposed for testing. protected val serializer = new ThreadLocal[SerializerInstance] { override def initialValue(): SerializerInstance = { sparkEnv.closureSerializer.newInstance() @@ -45,7 +48,9 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul } def enqueueSuccessfulTask( - taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer) { + taskSetManager: TaskSetManager, + tid: Long, + serializedData: ByteBuffer): Unit = { getTaskResultExecutor.execute(new Runnable { override def run(): Unit = Utils.logUncaughtExceptions { try { @@ -82,7 +87,19 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul (deserializedResult, size) } - result.metrics.setResultSize(size) + // Set the task result size in the accumulator updates received from the executors. + // We need to do this here on the driver because if we did this on the executors then + // we would have to serialize the result again after updating the size. + result.accumUpdates = result.accumUpdates.map { a => + if (a.name == Some(InternalAccumulator.RESULT_SIZE)) { + assert(a.update == Some(0L), + "task result size should not have been set on the executors") + a.copy(update = Some(size.toLong)) + } else { + a + } + } + scheduler.handleSuccessfulTask(taskSetManager, tid, result) } catch { case cnf: ClassNotFoundException => diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 7c0b007db708..fccd6e069934 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -65,8 +65,10 @@ private[spark] trait TaskScheduler { * alive. Return true if the driver knows about the given block manager. Otherwise, return false, * indicating that the block manager should re-register. */ - def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)], - blockManagerId: BlockManagerId): Boolean + def executorHeartbeatReceived( + execId: String, + accumUpdates: Array[(Long, Seq[AccumulableInfo])], + blockManagerId: BlockManagerId): Boolean /** * Get an application ID associated with the job. 
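executorHeartbeatReceived now takes Array[(Long, Seq[AccumulableInfo])] rather than Array[(Long, TaskMetrics)], so heartbeats ship per-task accumulator snapshots instead of whole TaskMetrics objects. The sketch below shows, under stated assumptions, how such a payload could be assembled from per-task TaskMetrics via the accumulatorUpdates() method added in this patch; the object and parameter names are placeholders, and because accumulatorUpdates() and TaskScheduler are Spark-internal the snippet is framed inside the org.apache.spark package.

    package org.apache.spark

    import org.apache.spark.executor.TaskMetrics
    import org.apache.spark.scheduler.{AccumulableInfo, TaskScheduler}
    import org.apache.spark.storage.BlockManagerId

    object HeartbeatPayloadSketch {
      // Pair each running task id with a snapshot of its accumulator values.
      def build(running: Map[Long, TaskMetrics]): Array[(Long, Seq[AccumulableInfo])] =
        running.map { case (taskId, metrics) => taskId -> metrics.accumulatorUpdates() }.toArray

      // The scheduler receives only per-task accumulator updates, not TaskMetrics.
      def report(
          scheduler: TaskScheduler,
          execId: String,
          running: Map[Long, TaskMetrics],
          blockManagerId: BlockManagerId): Boolean = {
        scheduler.executorHeartbeatReceived(execId, build(running), blockManagerId)
      }
    }

TaskSchedulerImpl then re-keys these updates by stage and stage attempt before forwarding them to the DAGScheduler, as the next hunk shows.
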
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 6e3ef0e54f0f..29341dfe3043 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -30,7 +30,6 @@ import scala.util.Random import org.apache.spark._ import org.apache.spark.TaskState.TaskState -import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.scheduler.TaskLocality.TaskLocality import org.apache.spark.storage.BlockManagerId @@ -380,17 +379,17 @@ private[spark] class TaskSchedulerImpl( */ override def executorHeartbeatReceived( execId: String, - taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics + accumUpdates: Array[(Long, Seq[AccumulableInfo])], blockManagerId: BlockManagerId): Boolean = { - - val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized { - taskMetrics.flatMap { case (id, metrics) => + // (taskId, stageId, stageAttemptId, accumUpdates) + val accumUpdatesWithTaskIds: Array[(Long, Int, Int, Seq[AccumulableInfo])] = synchronized { + accumUpdates.flatMap { case (id, updates) => taskIdToTaskSetManager.get(id).map { taskSetMgr => - (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics) + (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, updates) } } } - dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId) + dagScheduler.executorHeartbeatReceived(execId, accumUpdatesWithTaskIds, blockManagerId) } def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long): Unit = synchronized { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index aa39b59d8cce..cf97877476d5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -621,8 +621,7 @@ private[spark] class TaskSetManager( // "result.value()" in "TaskResultGetter.enqueueSuccessfulTask" before reaching here. // Note: "result.value()" only deserializes the value when it's called at the first time, so // here "result.value()" just returns the value and won't block other threads. 
- sched.dagScheduler.taskEnded( - tasks(index), Success, result.value(), result.accumUpdates, info, result.metrics) + sched.dagScheduler.taskEnded(tasks(index), Success, result.value(), result.accumUpdates, info) if (!successful(index)) { tasksSuccessful += 1 logInfo("Finished task %s in stage %s (TID %d) in %d ms on %s (%d/%d)".format( @@ -653,8 +652,7 @@ private[spark] class TaskSetManager( info.markFailed() val index = info.index copiesRunning(index) -= 1 - var taskMetrics : TaskMetrics = null - + var accumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo] val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}): " + reason.asInstanceOf[TaskFailedReason].toErrorString val failureException: Option[Throwable] = reason match { @@ -669,7 +667,8 @@ private[spark] class TaskSetManager( None case ef: ExceptionFailure => - taskMetrics = ef.metrics.orNull + // ExceptionFailure's might have accumulator updates + accumUpdates = ef.accumUpdates if (ef.className == classOf[NotSerializableException].getName) { // If the task result wasn't serializable, there's no point in trying to re-execute it. logError("Task %s in stage %s (TID %d) had a not serializable result: %s; not retrying" @@ -721,7 +720,7 @@ private[spark] class TaskSetManager( // always add to failed executors failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()). put(info.executorId, clock.getTimeMillis()) - sched.dagScheduler.taskEnded(tasks(index), reason, null, null, info, taskMetrics) + sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, info) addPendingTask(index) if (!isZombie && state != TaskState.KILLED && reason.isInstanceOf[TaskFailedReason] @@ -793,7 +792,8 @@ private[spark] class TaskSetManager( addPendingTask(index) // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our // stage finishes when a total of tasks.size tasks finish. 
- sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null) + sched.dagScheduler.taskEnded( + tasks(index), Resubmitted, null, Seq.empty[AccumulableInfo], info) } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index a57e5b0bfb86..acbe16001f5b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -103,8 +103,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( sorter.insertAll(aggregatedIter) context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) - context.internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes) + context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) case None => aggregatedIter diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 078718ba1126..9c92a501503c 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -237,7 +237,8 @@ private[v1] object AllStagesResource { } def convertAccumulableInfo(acc: InternalAccumulableInfo): AccumulableInfo = { - new AccumulableInfo(acc.id, acc.name, acc.update, acc.value) + new AccumulableInfo( + acc.id, acc.name.orNull, acc.update.map(_.toString), acc.value.map(_.toString).orNull) } def convertUiTaskMetrics(internal: InternalTaskMetrics): TaskMetrics = { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 4a9f8b30525f..b2aa8bfbe700 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -325,12 +325,13 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized { val taskInfo = taskStart.taskInfo if (taskInfo != null) { + val metrics = new TaskMetrics val stageData = stageIdToData.getOrElseUpdate((taskStart.stageId, taskStart.stageAttemptId), { logWarning("Task start for unknown stage " + taskStart.stageId) new StageUIData }) stageData.numActiveTasks += 1 - stageData.taskData.put(taskInfo.taskId, new TaskUIData(taskInfo)) + stageData.taskData.put(taskInfo.taskId, new TaskUIData(taskInfo, Some(metrics))) } for ( activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskStart.stageId); @@ -387,9 +388,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { (Some(e.toErrorString), None) } - if (!metrics.isEmpty) { + metrics.foreach { m => val oldMetrics = stageData.taskData.get(info.taskId).flatMap(_.taskMetrics) - updateAggregateMetrics(stageData, info.executorId, metrics.get, oldMetrics) + updateAggregateMetrics(stageData, info.executorId, m, oldMetrics) } val taskData = stageData.taskData.getOrElseUpdate(info.taskId, new TaskUIData(info)) @@ -489,19 +490,18 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { } override def 
onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) { - for ((taskId, sid, sAttempt, taskMetrics) <- executorMetricsUpdate.taskMetrics) { + for ((taskId, sid, sAttempt, accumUpdates) <- executorMetricsUpdate.accumUpdates) { val stageData = stageIdToData.getOrElseUpdate((sid, sAttempt), { logWarning("Metrics update for task in unknown stage " + sid) new StageUIData }) val taskData = stageData.taskData.get(taskId) - taskData.map { t => + val metrics = TaskMetrics.fromAccumulatorUpdates(accumUpdates) + taskData.foreach { t => if (!t.taskInfo.finished) { - updateAggregateMetrics(stageData, executorMetricsUpdate.execId, taskMetrics, - t.taskMetrics) - + updateAggregateMetrics(stageData, executorMetricsUpdate.execId, metrics, t.taskMetrics) // Overwrite task metrics - t.taskMetrics = Some(taskMetrics) + t.taskMetrics = Some(metrics) } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 914f6183cc2a..29c5ff0b5cf0 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -271,8 +271,12 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { } val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value") - def accumulableRow(acc: AccumulableInfo): Elem = - {acc.name}{acc.value} + def accumulableRow(acc: AccumulableInfo): Seq[Node] = { + (acc.name, acc.value) match { + case (Some(name), Some(value)) => {name}{value} + case _ => Seq.empty[Node] + } + } val accumulableTable = UIUtils.listingTable( accumulableHeaders, accumulableRow, @@ -404,13 +408,9 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { +: getFormattedTimeQuantiles(gettingResultTimes) - val peakExecutionMemory = validTasks.map { case TaskUIData(info, _, _) => - info.accumulables - .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY } - .map { acc => acc.update.getOrElse("0").toLong } - .getOrElse(0L) - .toDouble - } + val peakExecutionMemory = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.peakExecutionMemory.toDouble + } val peakExecutionMemoryQuantiles = { - StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}") - } - val peakExecutionMemoryUsed = taskInternalAccumulables - .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY } - .map { acc => acc.update.getOrElse("0").toLong } - .getOrElse(0L) + val externalAccumulableReadable = info.accumulables + .filterNot(_.internal) + .flatMap { a => + (a.name, a.update) match { + case (Some(name), Some(update)) => Some(StringEscapeUtils.escapeHtml4(s"$name: $update")) + case _ => None + } + } + val peakExecutionMemoryUsed = metrics.map(_.peakExecutionMemory).getOrElse(0L) val maybeInput = metrics.flatMap(_.inputMetrics) val inputSortable = maybeInput.map(_.bytesRead).getOrElse(0L) diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index efa22b99936a..dc8070cf8aad 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -233,14 +233,14 @@ private[spark] object JsonProtocol { def executorMetricsUpdateToJson(metricsUpdate: SparkListenerExecutorMetricsUpdate): JValue = { val execId = metricsUpdate.execId - val taskMetrics = metricsUpdate.taskMetrics + val accumUpdates = metricsUpdate.accumUpdates ("Event" -> 
Utils.getFormattedClassName(metricsUpdate)) ~ ("Executor ID" -> execId) ~ - ("Metrics Updated" -> taskMetrics.map { case (taskId, stageId, stageAttemptId, metrics) => + ("Metrics Updated" -> accumUpdates.map { case (taskId, stageId, stageAttemptId, updates) => ("Task ID" -> taskId) ~ ("Stage ID" -> stageId) ~ ("Stage Attempt ID" -> stageAttemptId) ~ - ("Task Metrics" -> taskMetricsToJson(metrics)) + ("Accumulator Updates" -> JArray(updates.map(accumulableInfoToJson).toList)) }) } @@ -265,7 +265,7 @@ private[spark] object JsonProtocol { ("Completion Time" -> completionTime) ~ ("Failure Reason" -> failureReason) ~ ("Accumulables" -> JArray( - stageInfo.accumulables.values.map(accumulableInfoToJson).toList)) + stageInfo.accumulables.values.map(accumulableInfoToJson).toList)) } def taskInfoToJson(taskInfo: TaskInfo): JValue = { @@ -284,11 +284,44 @@ private[spark] object JsonProtocol { } def accumulableInfoToJson(accumulableInfo: AccumulableInfo): JValue = { + val name = accumulableInfo.name ("ID" -> accumulableInfo.id) ~ - ("Name" -> accumulableInfo.name) ~ - ("Update" -> accumulableInfo.update.map(new JString(_)).getOrElse(JNothing)) ~ - ("Value" -> accumulableInfo.value) ~ - ("Internal" -> accumulableInfo.internal) + ("Name" -> name) ~ + ("Update" -> accumulableInfo.update.map { v => accumValueToJson(name, v) }) ~ + ("Value" -> accumulableInfo.value.map { v => accumValueToJson(name, v) }) ~ + ("Internal" -> accumulableInfo.internal) ~ + ("Count Failed Values" -> accumulableInfo.countFailedValues) + } + + /** + * Serialize the value of an accumulator to JSON. + * + * For accumulators representing internal task metrics, this looks up the relevant + * [[AccumulatorParam]] to serialize the value accordingly. For all other accumulators, + * this will simply serialize the value as a string. + * + * The behavior here must match that of [[accumValueFromJson]]. Exposed for testing. 
+ */ + private[util] def accumValueToJson(name: Option[String], value: Any): JValue = { + import AccumulatorParam._ + if (name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX))) { + (value, InternalAccumulator.getParam(name.get)) match { + case (v: Int, IntAccumulatorParam) => JInt(v) + case (v: Long, LongAccumulatorParam) => JInt(v) + case (v: String, StringAccumulatorParam) => JString(v) + case (v, UpdatedBlockStatusesAccumulatorParam) => + JArray(v.asInstanceOf[Seq[(BlockId, BlockStatus)]].toList.map { case (id, status) => + ("Block ID" -> id.toString) ~ + ("Status" -> blockStatusToJson(status)) + }) + case (v, p) => + throw new IllegalArgumentException(s"unexpected combination of accumulator value " + + s"type (${v.getClass.getName}) and param (${p.getClass.getName}) in '${name.get}'") + } + } else { + // For all external accumulators, just use strings + JString(value.toString) + } } def taskMetricsToJson(taskMetrics: TaskMetrics): JValue = { @@ -303,9 +336,9 @@ private[spark] object JsonProtocol { }.getOrElse(JNothing) val shuffleWriteMetrics: JValue = taskMetrics.shuffleWriteMetrics.map { wm => - ("Shuffle Bytes Written" -> wm.shuffleBytesWritten) ~ - ("Shuffle Write Time" -> wm.shuffleWriteTime) ~ - ("Shuffle Records Written" -> wm.shuffleRecordsWritten) + ("Shuffle Bytes Written" -> wm.bytesWritten) ~ + ("Shuffle Write Time" -> wm.writeTime) ~ + ("Shuffle Records Written" -> wm.recordsWritten) }.getOrElse(JNothing) val inputMetrics: JValue = taskMetrics.inputMetrics.map { im => @@ -324,7 +357,6 @@ private[spark] object JsonProtocol { ("Block ID" -> id.toString) ~ ("Status" -> blockStatusToJson(status)) }) - ("Host Name" -> taskMetrics.hostname) ~ ("Executor Deserialize Time" -> taskMetrics.executorDeserializeTime) ~ ("Executor Run Time" -> taskMetrics.executorRunTime) ~ ("Result Size" -> taskMetrics.resultSize) ~ @@ -352,12 +384,12 @@ private[spark] object JsonProtocol { ("Message" -> fetchFailed.message) case exceptionFailure: ExceptionFailure => val stackTrace = stackTraceToJson(exceptionFailure.stackTrace) - val metrics = exceptionFailure.metrics.map(taskMetricsToJson).getOrElse(JNothing) + val accumUpdates = JArray(exceptionFailure.accumUpdates.map(accumulableInfoToJson).toList) ("Class Name" -> exceptionFailure.className) ~ ("Description" -> exceptionFailure.description) ~ ("Stack Trace" -> stackTrace) ~ ("Full Stack Trace" -> exceptionFailure.fullStackTrace) ~ - ("Metrics" -> metrics) + ("Accumulator Updates" -> accumUpdates) case taskCommitDenied: TaskCommitDenied => ("Job ID" -> taskCommitDenied.jobID) ~ ("Partition ID" -> taskCommitDenied.partitionID) ~ @@ -619,14 +651,15 @@ private[spark] object JsonProtocol { def executorMetricsUpdateFromJson(json: JValue): SparkListenerExecutorMetricsUpdate = { val execInfo = (json \ "Executor ID").extract[String] - val taskMetrics = (json \ "Metrics Updated").extract[List[JValue]].map { json => + val accumUpdates = (json \ "Metrics Updated").extract[List[JValue]].map { json => val taskId = (json \ "Task ID").extract[Long] val stageId = (json \ "Stage ID").extract[Int] val stageAttemptId = (json \ "Stage Attempt ID").extract[Int] - val metrics = taskMetricsFromJson(json \ "Task Metrics") - (taskId, stageId, stageAttemptId, metrics) + val updates = + (json \ "Accumulator Updates").extract[List[JValue]].map(accumulableInfoFromJson) + (taskId, stageId, stageAttemptId, updates) } - SparkListenerExecutorMetricsUpdate(execInfo, taskMetrics) + SparkListenerExecutorMetricsUpdate(execInfo, accumUpdates) } /** 
--------------------------------------------------------------------- * @@ -647,7 +680,7 @@ private[spark] object JsonProtocol { val completionTime = Utils.jsonOption(json \ "Completion Time").map(_.extract[Long]) val failureReason = Utils.jsonOption(json \ "Failure Reason").map(_.extract[String]) val accumulatedValues = (json \ "Accumulables").extractOpt[List[JValue]] match { - case Some(values) => values.map(accumulableInfoFromJson(_)) + case Some(values) => values.map(accumulableInfoFromJson) case None => Seq[AccumulableInfo]() } @@ -675,7 +708,7 @@ private[spark] object JsonProtocol { val finishTime = (json \ "Finish Time").extract[Long] val failed = (json \ "Failed").extract[Boolean] val accumulables = (json \ "Accumulables").extractOpt[Seq[JValue]] match { - case Some(values) => values.map(accumulableInfoFromJson(_)) + case Some(values) => values.map(accumulableInfoFromJson) case None => Seq[AccumulableInfo]() } @@ -690,11 +723,43 @@ private[spark] object JsonProtocol { def accumulableInfoFromJson(json: JValue): AccumulableInfo = { val id = (json \ "ID").extract[Long] - val name = (json \ "Name").extract[String] - val update = Utils.jsonOption(json \ "Update").map(_.extract[String]) - val value = (json \ "Value").extract[String] + val name = (json \ "Name").extractOpt[String] + val update = Utils.jsonOption(json \ "Update").map { v => accumValueFromJson(name, v) } + val value = Utils.jsonOption(json \ "Value").map { v => accumValueFromJson(name, v) } val internal = (json \ "Internal").extractOpt[Boolean].getOrElse(false) - AccumulableInfo(id, name, update, value, internal) + val countFailedValues = (json \ "Count Failed Values").extractOpt[Boolean].getOrElse(false) + new AccumulableInfo(id, name, update, value, internal, countFailedValues) + } + + /** + * Deserialize the value of an accumulator from JSON. + * + * For accumulators representing internal task metrics, this looks up the relevant + * [[AccumulatorParam]] to deserialize the value accordingly. For all other + * accumulators, this will simply deserialize the value as a string. + * + * The behavior here must match that of [[accumValueToJson]]. Exposed for testing. 
+ */ + private[util] def accumValueFromJson(name: Option[String], value: JValue): Any = { + import AccumulatorParam._ + if (name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX))) { + (value, InternalAccumulator.getParam(name.get)) match { + case (JInt(v), IntAccumulatorParam) => v.toInt + case (JInt(v), LongAccumulatorParam) => v.toLong + case (JString(v), StringAccumulatorParam) => v + case (JArray(v), UpdatedBlockStatusesAccumulatorParam) => + v.map { blockJson => + val id = BlockId((blockJson \ "Block ID").extract[String]) + val status = blockStatusFromJson(blockJson \ "Status") + (id, status) + } + case (v, p) => + throw new IllegalArgumentException(s"unexpected combination of accumulator " + + s"value in JSON ($v) and accumulator param (${p.getClass.getName}) in '${name.get}'") + } + } else { + value.extract[String] + } } def taskMetricsFromJson(json: JValue): TaskMetrics = { @@ -702,7 +767,6 @@ private[spark] object JsonProtocol { return TaskMetrics.empty } val metrics = new TaskMetrics - metrics.setHostname((json \ "Host Name").extract[String]) metrics.setExecutorDeserializeTime((json \ "Executor Deserialize Time").extract[Long]) metrics.setExecutorRunTime((json \ "Executor Run Time").extract[Long]) metrics.setResultSize((json \ "Result Size").extract[Long]) @@ -787,10 +851,12 @@ private[spark] object JsonProtocol { val className = (json \ "Class Name").extract[String] val description = (json \ "Description").extract[String] val stackTrace = stackTraceFromJson(json \ "Stack Trace") - val fullStackTrace = Utils.jsonOption(json \ "Full Stack Trace"). - map(_.extract[String]).orNull - val metrics = Utils.jsonOption(json \ "Metrics").map(taskMetricsFromJson) - ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics, None) + val fullStackTrace = (json \ "Full Stack Trace").extractOpt[String].orNull + // Fallback on getting accumulator updates from TaskMetrics, which was logged in Spark 1.x + val accumUpdates = Utils.jsonOption(json \ "Accumulator Updates") + .map(_.extract[List[JValue]].map(accumulableInfoFromJson)) + .getOrElse(taskMetricsFromJson(json \ "Metrics").accumulatorUpdates()) + ExceptionFailure(className, description, stackTrace, fullStackTrace, None, accumUpdates) case `taskResultLost` => TaskResultLost case `taskKilled` => TaskKilled case `taskCommitDenied` => diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index df9e0502e736..5afd6d6e22c6 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -682,8 +682,7 @@ private[spark] class ExternalSorter[K, V, C]( context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled) context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled) - context.internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemoryUsedBytes) + context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes) lengths } diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 625fdd57eb5d..876c3a228364 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -191,8 +191,6 @@ public Tuple2 answer( }); 
when(taskContext.taskMetrics()).thenReturn(taskMetrics); - when(taskContext.internalMetricsToAccumulators()).thenReturn(null); - when(shuffleDep.serializer()).thenReturn(Option.apply(serializer)); when(shuffleDep.partitioner()).thenReturn(hashPartitioner); } diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 5b84acf40be4..11c97d7d9a44 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -17,18 +17,22 @@ package org.apache.spark +import javax.annotation.concurrent.GuardedBy + import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.ref.WeakReference +import scala.util.control.NonFatal import org.scalatest.Matchers import org.scalatest.exceptions.TestFailedException import org.apache.spark.scheduler._ +import org.apache.spark.serializer.JavaSerializer class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext { - import InternalAccumulator._ + import AccumulatorParam._ implicit def setAccum[A]: AccumulableParam[mutable.Set[A], A] = new AccumulableParam[mutable.Set[A], A] { @@ -59,7 +63,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex longAcc.value should be (210L + maxInt * 20) } - test ("value not assignable from tasks") { + test("value not assignable from tasks") { sc = new SparkContext("local", "test") val acc : Accumulator[Int] = sc.accumulator(0) @@ -84,7 +88,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex } } - test ("value not readable in tasks") { + test("value not readable in tasks") { val maxI = 1000 for (nThreads <- List(1, 10)) { // test single & multi-threaded sc = new SparkContext("local[" + nThreads + "]", "test") @@ -159,193 +163,157 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex assert(!Accumulators.originals.get(accId).isDefined) } - test("internal accumulators in TaskContext") { + test("get accum") { sc = new SparkContext("local", "test") - val accums = InternalAccumulator.create(sc) - val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null, accums) - val internalMetricsToAccums = taskContext.internalMetricsToAccumulators - val collectedInternalAccums = taskContext.collectInternalAccumulators() - val collectedAccums = taskContext.collectAccumulators() - assert(internalMetricsToAccums.size > 0) - assert(internalMetricsToAccums.values.forall(_.isInternal)) - assert(internalMetricsToAccums.contains(TEST_ACCUMULATOR)) - val testAccum = internalMetricsToAccums(TEST_ACCUMULATOR) - assert(collectedInternalAccums.size === internalMetricsToAccums.size) - assert(collectedInternalAccums.size === collectedAccums.size) - assert(collectedInternalAccums.contains(testAccum.id)) - assert(collectedAccums.contains(testAccum.id)) - } + // Don't register with SparkContext for cleanup + var acc = new Accumulable[Int, Int](0, IntAccumulatorParam, None, true, true) + val accId = acc.id + val ref = WeakReference(acc) + assert(ref.get.isDefined) + Accumulators.register(ref.get.get) - test("internal accumulators in a stage") { - val listener = new SaveInfoListener - val numPartitions = 10 - sc = new SparkContext("local", "test") - sc.addSparkListener(listener) - // Have each task add 1 to the internal accumulator - val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitions { iter => - TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) 
+= 1 - iter - } - // Register asserts in job completion callback to avoid flakiness - listener.registerJobCompletionCallback { _ => - val stageInfos = listener.getCompletedStageInfos - val taskInfos = listener.getCompletedTaskInfos - assert(stageInfos.size === 1) - assert(taskInfos.size === numPartitions) - // The accumulator values should be merged in the stage - val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR) - assert(stageAccum.value.toLong === numPartitions) - // The accumulator should be updated locally on each task - val taskAccumValues = taskInfos.map { taskInfo => - val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR) - assert(taskAccum.update.isDefined) - assert(taskAccum.update.get.toLong === 1) - taskAccum.value.toLong - } - // Each task should keep track of the partial value on the way, i.e. 1, 2, ... numPartitions - assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) + // Remove the explicit reference to it and allow weak reference to get garbage collected + acc = null + System.gc() + assert(ref.get.isEmpty) + + // Getting a garbage collected accum should throw error + intercept[IllegalAccessError] { + Accumulators.get(accId) } - rdd.count() + + // Getting a normal accumulator. Note: this has to be separate because referencing an + // accumulator above in an `assert` would keep it from being garbage collected. + val acc2 = new Accumulable[Long, Long](0L, LongAccumulatorParam, None, true, true) + Accumulators.register(acc2) + assert(Accumulators.get(acc2.id) === Some(acc2)) + + // Getting an accumulator that does not exist should return None + assert(Accumulators.get(100000).isEmpty) } - test("internal accumulators in multiple stages") { - val listener = new SaveInfoListener - val numPartitions = 10 - sc = new SparkContext("local", "test") - sc.addSparkListener(listener) - // Each stage creates its own set of internal accumulators so the - // values for the same metric should not be mixed up across stages - val rdd = sc.parallelize(1 to 100, numPartitions) - .map { i => (i, i) } - .mapPartitions { iter => - TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 - iter - } - .reduceByKey { case (x, y) => x + y } - .mapPartitions { iter => - TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 10 - iter - } - .repartition(numPartitions * 2) - .mapPartitions { iter => - TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 100 - iter - } - // Register asserts in job completion callback to avoid flakiness - listener.registerJobCompletionCallback { _ => - // We ran 3 stages, and the accumulator values should be distinct - val stageInfos = listener.getCompletedStageInfos - assert(stageInfos.size === 3) - val (firstStageAccum, secondStageAccum, thirdStageAccum) = - (findAccumulableInfo(stageInfos(0).accumulables.values, TEST_ACCUMULATOR), - findAccumulableInfo(stageInfos(1).accumulables.values, TEST_ACCUMULATOR), - findAccumulableInfo(stageInfos(2).accumulables.values, TEST_ACCUMULATOR)) - assert(firstStageAccum.value.toLong === numPartitions) - assert(secondStageAccum.value.toLong === numPartitions * 10) - assert(thirdStageAccum.value.toLong === numPartitions * 2 * 100) - } - rdd.count() + test("only external accums are automatically registered") { + val accEx = new Accumulator(0, IntAccumulatorParam, Some("external"), internal = false) + val accIn = new Accumulator(0, IntAccumulatorParam, Some("internal"), internal = true) + assert(!accEx.isInternal) + 
assert(accIn.isInternal) + assert(Accumulators.get(accEx.id).isDefined) + assert(Accumulators.get(accIn.id).isEmpty) } - test("internal accumulators in fully resubmitted stages") { - testInternalAccumulatorsWithFailedTasks((i: Int) => true) // fail all tasks + test("copy") { + val acc1 = new Accumulable[Long, Long](456L, LongAccumulatorParam, Some("x"), true, false) + val acc2 = acc1.copy() + assert(acc1.id === acc2.id) + assert(acc1.value === acc2.value) + assert(acc1.name === acc2.name) + assert(acc1.isInternal === acc2.isInternal) + assert(acc1.countFailedValues === acc2.countFailedValues) + assert(acc1 !== acc2) + // Modifying one does not affect the other + acc1.add(44L) + assert(acc1.value === 500L) + assert(acc2.value === 456L) + acc2.add(144L) + assert(acc1.value === 500L) + assert(acc2.value === 600L) } - test("internal accumulators in partially resubmitted stages") { - testInternalAccumulatorsWithFailedTasks((i: Int) => i % 2 == 0) // fail a subset + test("register multiple accums with same ID") { + // Make sure these are internal accums so we don't automatically register them already + val acc1 = new Accumulable[Int, Int](0, IntAccumulatorParam, None, true, true) + val acc2 = acc1.copy() + assert(acc1 !== acc2) + assert(acc1.id === acc2.id) + assert(Accumulators.originals.isEmpty) + assert(Accumulators.get(acc1.id).isEmpty) + Accumulators.register(acc1) + Accumulators.register(acc2) + // The second one does not override the first one + assert(Accumulators.originals.size === 1) + assert(Accumulators.get(acc1.id) === Some(acc1)) } - /** - * Return the accumulable info that matches the specified name. - */ - private def findAccumulableInfo( - accums: Iterable[AccumulableInfo], - name: String): AccumulableInfo = { - accums.find { a => a.name == name }.getOrElse { - throw new TestFailedException(s"internal accumulator '$name' not found", 0) - } + test("string accumulator param") { + val acc = new Accumulator("", StringAccumulatorParam, Some("darkness")) + assert(acc.value === "") + acc.setValue("feeds") + assert(acc.value === "feeds") + acc.add("your") + assert(acc.value === "your") // value is overwritten, not concatenated + acc += "soul" + assert(acc.value === "soul") + acc ++= "with" + assert(acc.value === "with") + acc.merge("kindness") + assert(acc.value === "kindness") } - /** - * Test whether internal accumulators are merged properly if some tasks fail. 
- */ - private def testInternalAccumulatorsWithFailedTasks(failCondition: (Int => Boolean)): Unit = { - val listener = new SaveInfoListener - val numPartitions = 10 - val numFailedPartitions = (0 until numPartitions).count(failCondition) - // This says use 1 core and retry tasks up to 2 times - sc = new SparkContext("local[1, 2]", "test") - sc.addSparkListener(listener) - val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) => - val taskContext = TaskContext.get() - taskContext.internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 - // Fail the first attempts of a subset of the tasks - if (failCondition(i) && taskContext.attemptNumber() == 0) { - throw new Exception("Failing a task intentionally.") - } - iter - } - // Register asserts in job completion callback to avoid flakiness - listener.registerJobCompletionCallback { _ => - val stageInfos = listener.getCompletedStageInfos - val taskInfos = listener.getCompletedTaskInfos - assert(stageInfos.size === 1) - assert(taskInfos.size === numPartitions + numFailedPartitions) - val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR) - // We should not double count values in the merged accumulator - assert(stageAccum.value.toLong === numPartitions) - val taskAccumValues = taskInfos.flatMap { taskInfo => - if (!taskInfo.failed) { - // If a task succeeded, its update value should always be 1 - val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR) - assert(taskAccum.update.isDefined) - assert(taskAccum.update.get.toLong === 1) - Some(taskAccum.value.toLong) - } else { - // If a task failed, we should not get its accumulator values - assert(taskInfo.accumulables.isEmpty) - None - } - } - assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) - } - rdd.count() + test("list accumulator param") { + val acc = new Accumulator(Seq.empty[Int], new ListAccumulatorParam[Int], Some("numbers")) + assert(acc.value === Seq.empty[Int]) + acc.add(Seq(1, 2)) + assert(acc.value === Seq(1, 2)) + acc += Seq(3, 4) + assert(acc.value === Seq(1, 2, 3, 4)) + acc ++= Seq(5, 6) + assert(acc.value === Seq(1, 2, 3, 4, 5, 6)) + acc.merge(Seq(7, 8)) + assert(acc.value === Seq(1, 2, 3, 4, 5, 6, 7, 8)) + acc.setValue(Seq(9, 10)) + assert(acc.value === Seq(9, 10)) + } + + test("value is reset on the executors") { + val acc1 = new Accumulator(0, IntAccumulatorParam, Some("thing"), internal = false) + val acc2 = new Accumulator(0L, LongAccumulatorParam, Some("thing2"), internal = false) + val externalAccums = Seq(acc1, acc2) + val internalAccums = InternalAccumulator.create() + // Set some values; these should not be observed later on the "executors" + acc1.setValue(10) + acc2.setValue(20L) + internalAccums + .find(_.name == Some(InternalAccumulator.TEST_ACCUM)) + .get.asInstanceOf[Accumulator[Long]] + .setValue(30L) + // Simulate the task being serialized and sent to the executors. + val dummyTask = new DummyTask(internalAccums, externalAccums) + val serInstance = new JavaSerializer(new SparkConf).newInstance() + val taskSer = Task.serializeWithDependencies( + dummyTask, mutable.HashMap(), mutable.HashMap(), serInstance) + // Now we're on the executors. + // Deserialize the task and assert that its accumulators are zero'ed out. 
+ val (_, _, taskBytes) = Task.deserializeWithDependencies(taskSer) + val taskDeser = serInstance.deserialize[DummyTask]( + taskBytes, Thread.currentThread.getContextClassLoader) + // Assert that executors see only zeros + taskDeser.externalAccums.foreach { a => assert(a.localValue == a.zero) } + taskDeser.internalAccums.foreach { a => assert(a.localValue == a.zero) } } } private[spark] object AccumulatorSuite { + import InternalAccumulator._ + /** - * Run one or more Spark jobs and verify that the peak execution memory accumulator - * is updated afterwards. + * Run one or more Spark jobs and verify that in at least one job the peak execution memory + * accumulator is updated afterwards. */ def verifyPeakExecutionMemorySet( sc: SparkContext, testName: String)(testBody: => Unit): Unit = { val listener = new SaveInfoListener sc.addSparkListener(listener) - // Register asserts in job completion callback to avoid flakiness - listener.registerJobCompletionCallback { jobId => - if (jobId == 0) { - // The first job is a dummy one to verify that the accumulator does not already exist - val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values) - assert(!accums.exists(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY)) - } else { - // In the subsequent jobs, verify that peak execution memory is updated - val accum = listener.getCompletedStageInfos - .flatMap(_.accumulables.values) - .find(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY) - .getOrElse { - throw new TestFailedException( - s"peak execution memory accumulator not set in '$testName'", 0) - } - assert(accum.value.toLong > 0) - } - } - // Run the jobs - sc.parallelize(1 to 10).count() testBody + val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values) + val isSet = accums.exists { a => + a.name == Some(PEAK_EXECUTION_MEMORY) && a.value.exists(_.asInstanceOf[Long] > 0L) + } + if (!isSet) { + throw new TestFailedException(s"peak execution memory accumulator not set in '$testName'", 0) + } } } @@ -357,6 +325,10 @@ private class SaveInfoListener extends SparkListener { private val completedTaskInfos: ArrayBuffer[TaskInfo] = new ArrayBuffer[TaskInfo] private var jobCompletionCallback: (Int => Unit) = null // parameter is job ID + // Accesses must be synchronized to ensure failures in `jobCompletionCallback` are propagated + @GuardedBy("this") + private var exception: Throwable = null + def getCompletedStageInfos: Seq[StageInfo] = completedStageInfos.toArray.toSeq def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.toArray.toSeq @@ -365,9 +337,20 @@ private class SaveInfoListener extends SparkListener { jobCompletionCallback = callback } - override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + /** Throw a stored exception, if any. */ + def maybeThrowException(): Unit = synchronized { + if (exception != null) { throw exception } + } + + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized { if (jobCompletionCallback != null) { - jobCompletionCallback(jobEnd.jobId) + try { + jobCompletionCallback(jobEnd.jobId) + } catch { + // Store any exception thrown here so we can throw them later in the main thread. + // Otherwise, if `jobCompletionCallback` threw something it wouldn't fail the test. + case NonFatal(e) => exception = e + } } } @@ -379,3 +362,14 @@ private class SaveInfoListener extends SparkListener { completedTaskInfos += taskEnd.taskInfo } } + + +/** + * A dummy [[Task]] that contains internal and external [[Accumulator]]s. 
+ */ +private[spark] class DummyTask( + val internalAccums: Seq[Accumulator[_]], + val externalAccums: Seq[Accumulator[_]]) + extends Task[Int](0, 0, 0, internalAccums) { + override def runTask(c: TaskContext): Int = 1 +} diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 4e678fbac6a3..80a1de6065b4 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -801,7 +801,7 @@ class ExecutorAllocationManagerSuite assert(maxNumExecutorsNeeded(manager) === 1) // If the task is failed, we expect it to be resubmitted later. - val taskEndReason = ExceptionFailure(null, null, null, null, null, None) + val taskEndReason = ExceptionFailure(null, null, null, null, None) sc.listenerBus.postToAll(SparkListenerTaskEnd(0, 0, null, taskEndReason, taskInfo, null)) assert(maxNumExecutorsNeeded(manager) === 1) } diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index c7f629a14ba2..3777d77f8f5b 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -215,14 +215,16 @@ class HeartbeatReceiverSuite val metrics = new TaskMetrics val blockManagerId = BlockManagerId(executorId, "localhost", 12345) val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse]( - Heartbeat(executorId, Array(1L -> metrics), blockManagerId)) + Heartbeat(executorId, Array(1L -> metrics.accumulatorUpdates()), blockManagerId)) if (executorShouldReregister) { assert(response.reregisterBlockManager) } else { assert(!response.reregisterBlockManager) // Additionally verify that the scheduler callback is called with the correct parameters verify(scheduler).executorHeartbeatReceived( - Matchers.eq(executorId), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) + Matchers.eq(executorId), + Matchers.eq(Array(1L -> metrics.accumulatorUpdates())), + Matchers.eq(blockManagerId)) } } diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala new file mode 100644 index 000000000000..630b46f828df --- /dev/null +++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala @@ -0,0 +1,331 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.scheduler.AccumulableInfo +import org.apache.spark.storage.{BlockId, BlockStatus} + + +class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { + import InternalAccumulator._ + import AccumulatorParam._ + + test("get param") { + assert(getParam(EXECUTOR_DESERIALIZE_TIME) === LongAccumulatorParam) + assert(getParam(EXECUTOR_RUN_TIME) === LongAccumulatorParam) + assert(getParam(RESULT_SIZE) === LongAccumulatorParam) + assert(getParam(JVM_GC_TIME) === LongAccumulatorParam) + assert(getParam(RESULT_SERIALIZATION_TIME) === LongAccumulatorParam) + assert(getParam(MEMORY_BYTES_SPILLED) === LongAccumulatorParam) + assert(getParam(DISK_BYTES_SPILLED) === LongAccumulatorParam) + assert(getParam(PEAK_EXECUTION_MEMORY) === LongAccumulatorParam) + assert(getParam(UPDATED_BLOCK_STATUSES) === UpdatedBlockStatusesAccumulatorParam) + assert(getParam(TEST_ACCUM) === LongAccumulatorParam) + // shuffle read + assert(getParam(shuffleRead.REMOTE_BLOCKS_FETCHED) === IntAccumulatorParam) + assert(getParam(shuffleRead.LOCAL_BLOCKS_FETCHED) === IntAccumulatorParam) + assert(getParam(shuffleRead.REMOTE_BYTES_READ) === LongAccumulatorParam) + assert(getParam(shuffleRead.LOCAL_BYTES_READ) === LongAccumulatorParam) + assert(getParam(shuffleRead.FETCH_WAIT_TIME) === LongAccumulatorParam) + assert(getParam(shuffleRead.RECORDS_READ) === LongAccumulatorParam) + // shuffle write + assert(getParam(shuffleWrite.BYTES_WRITTEN) === LongAccumulatorParam) + assert(getParam(shuffleWrite.RECORDS_WRITTEN) === LongAccumulatorParam) + assert(getParam(shuffleWrite.WRITE_TIME) === LongAccumulatorParam) + // input + assert(getParam(input.READ_METHOD) === StringAccumulatorParam) + assert(getParam(input.RECORDS_READ) === LongAccumulatorParam) + assert(getParam(input.BYTES_READ) === LongAccumulatorParam) + // output + assert(getParam(output.WRITE_METHOD) === StringAccumulatorParam) + assert(getParam(output.RECORDS_WRITTEN) === LongAccumulatorParam) + assert(getParam(output.BYTES_WRITTEN) === LongAccumulatorParam) + // default to Long + assert(getParam(METRICS_PREFIX + "anything") === LongAccumulatorParam) + intercept[IllegalArgumentException] { + getParam("something that does not start with the right prefix") + } + } + + test("create by name") { + val executorRunTime = create(EXECUTOR_RUN_TIME) + val updatedBlockStatuses = create(UPDATED_BLOCK_STATUSES) + val shuffleRemoteBlocksRead = create(shuffleRead.REMOTE_BLOCKS_FETCHED) + val inputReadMethod = create(input.READ_METHOD) + assert(executorRunTime.name === Some(EXECUTOR_RUN_TIME)) + assert(updatedBlockStatuses.name === Some(UPDATED_BLOCK_STATUSES)) + assert(shuffleRemoteBlocksRead.name === Some(shuffleRead.REMOTE_BLOCKS_FETCHED)) + assert(inputReadMethod.name === Some(input.READ_METHOD)) + assert(executorRunTime.value.isInstanceOf[Long]) + assert(updatedBlockStatuses.value.isInstanceOf[Seq[_]]) + // We cannot assert the type of the value directly since the type parameter is erased. + // Instead, try casting a `Seq` of expected type and see if it fails in run time. 
+ updatedBlockStatuses.setValueAny(Seq.empty[(BlockId, BlockStatus)]) + assert(shuffleRemoteBlocksRead.value.isInstanceOf[Int]) + assert(inputReadMethod.value.isInstanceOf[String]) + // default to Long + val anything = create(METRICS_PREFIX + "anything") + assert(anything.value.isInstanceOf[Long]) + } + + test("create") { + val accums = create() + val shuffleReadAccums = createShuffleReadAccums() + val shuffleWriteAccums = createShuffleWriteAccums() + val inputAccums = createInputAccums() + val outputAccums = createOutputAccums() + // assert they're all internal + assert(accums.forall(_.isInternal)) + assert(shuffleReadAccums.forall(_.isInternal)) + assert(shuffleWriteAccums.forall(_.isInternal)) + assert(inputAccums.forall(_.isInternal)) + assert(outputAccums.forall(_.isInternal)) + // assert they all count on failures + assert(accums.forall(_.countFailedValues)) + assert(shuffleReadAccums.forall(_.countFailedValues)) + assert(shuffleWriteAccums.forall(_.countFailedValues)) + assert(inputAccums.forall(_.countFailedValues)) + assert(outputAccums.forall(_.countFailedValues)) + // assert they all have names + assert(accums.forall(_.name.isDefined)) + assert(shuffleReadAccums.forall(_.name.isDefined)) + assert(shuffleWriteAccums.forall(_.name.isDefined)) + assert(inputAccums.forall(_.name.isDefined)) + assert(outputAccums.forall(_.name.isDefined)) + // assert `accums` is a strict superset of the others + val accumNames = accums.map(_.name.get).toSet + val shuffleReadAccumNames = shuffleReadAccums.map(_.name.get).toSet + val shuffleWriteAccumNames = shuffleWriteAccums.map(_.name.get).toSet + val inputAccumNames = inputAccums.map(_.name.get).toSet + val outputAccumNames = outputAccums.map(_.name.get).toSet + assert(shuffleReadAccumNames.subsetOf(accumNames)) + assert(shuffleWriteAccumNames.subsetOf(accumNames)) + assert(inputAccumNames.subsetOf(accumNames)) + assert(outputAccumNames.subsetOf(accumNames)) + } + + test("naming") { + val accums = create() + val shuffleReadAccums = createShuffleReadAccums() + val shuffleWriteAccums = createShuffleWriteAccums() + val inputAccums = createInputAccums() + val outputAccums = createOutputAccums() + // assert that prefixes are properly namespaced + assert(SHUFFLE_READ_METRICS_PREFIX.startsWith(METRICS_PREFIX)) + assert(SHUFFLE_WRITE_METRICS_PREFIX.startsWith(METRICS_PREFIX)) + assert(INPUT_METRICS_PREFIX.startsWith(METRICS_PREFIX)) + assert(OUTPUT_METRICS_PREFIX.startsWith(METRICS_PREFIX)) + assert(accums.forall(_.name.get.startsWith(METRICS_PREFIX))) + // assert they all start with the expected prefixes + assert(shuffleReadAccums.forall(_.name.get.startsWith(SHUFFLE_READ_METRICS_PREFIX))) + assert(shuffleWriteAccums.forall(_.name.get.startsWith(SHUFFLE_WRITE_METRICS_PREFIX))) + assert(inputAccums.forall(_.name.get.startsWith(INPUT_METRICS_PREFIX))) + assert(outputAccums.forall(_.name.get.startsWith(OUTPUT_METRICS_PREFIX))) + } + + test("internal accumulators in TaskContext") { + val taskContext = TaskContext.empty() + val accumUpdates = taskContext.taskMetrics.accumulatorUpdates() + assert(accumUpdates.size > 0) + assert(accumUpdates.forall(_.internal)) + val testAccum = taskContext.taskMetrics.getAccum(TEST_ACCUM) + assert(accumUpdates.exists(_.id == testAccum.id)) + } + + test("internal accumulators in a stage") { + val listener = new SaveInfoListener + val numPartitions = 10 + sc = new SparkContext("local", "test") + sc.addSparkListener(listener) + // Have each task add 1 to the internal accumulator + val rdd = sc.parallelize(1 to 100, 
numPartitions).mapPartitions { iter => + TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 1 + iter + } + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { _ => + val stageInfos = listener.getCompletedStageInfos + val taskInfos = listener.getCompletedTaskInfos + assert(stageInfos.size === 1) + assert(taskInfos.size === numPartitions) + // The accumulator values should be merged in the stage + val stageAccum = findTestAccum(stageInfos.head.accumulables.values) + assert(stageAccum.value.get.asInstanceOf[Long] === numPartitions) + // The accumulator should be updated locally on each task + val taskAccumValues = taskInfos.map { taskInfo => + val taskAccum = findTestAccum(taskInfo.accumulables) + assert(taskAccum.update.isDefined) + assert(taskAccum.update.get.asInstanceOf[Long] === 1L) + taskAccum.value.get.asInstanceOf[Long] + } + // Each task should keep track of the partial value on the way, i.e. 1, 2, ... numPartitions + assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) + } + rdd.count() + } + + test("internal accumulators in multiple stages") { + val listener = new SaveInfoListener + val numPartitions = 10 + sc = new SparkContext("local", "test") + sc.addSparkListener(listener) + // Each stage creates its own set of internal accumulators so the + // values for the same metric should not be mixed up across stages + val rdd = sc.parallelize(1 to 100, numPartitions) + .map { i => (i, i) } + .mapPartitions { iter => + TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 1 + iter + } + .reduceByKey { case (x, y) => x + y } + .mapPartitions { iter => + TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 10 + iter + } + .repartition(numPartitions * 2) + .mapPartitions { iter => + TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 100 + iter + } + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { _ => + // We ran 3 stages, and the accumulator values should be distinct + val stageInfos = listener.getCompletedStageInfos + assert(stageInfos.size === 3) + val (firstStageAccum, secondStageAccum, thirdStageAccum) = + (findTestAccum(stageInfos(0).accumulables.values), + findTestAccum(stageInfos(1).accumulables.values), + findTestAccum(stageInfos(2).accumulables.values)) + assert(firstStageAccum.value.get.asInstanceOf[Long] === numPartitions) + assert(secondStageAccum.value.get.asInstanceOf[Long] === numPartitions * 10) + assert(thirdStageAccum.value.get.asInstanceOf[Long] === numPartitions * 2 * 100) + } + rdd.count() + } + + // TODO: these two tests are incorrect; they don't actually trigger stage retries. 
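+  // (Clarifying note on the TODO above, hedged against the failure handling exercised
+  // elsewhere in these suites: throwing a plain exception from a task only makes the
+  // TaskSetManager retry that task within the same stage attempt, whereas whole-stage
+  // resubmission is triggered by fetch failures such as
+  //   FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored")
+  // as used in DAGSchedulerSuite, which the two ignored tests below never produce.)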
+ ignore("internal accumulators in fully resubmitted stages") { + testInternalAccumulatorsWithFailedTasks((i: Int) => true) // fail all tasks + } + + ignore("internal accumulators in partially resubmitted stages") { + testInternalAccumulatorsWithFailedTasks((i: Int) => i % 2 == 0) // fail a subset + } + + test("internal accumulators are registered for cleanups") { + sc = new SparkContext("local", "test") { + private val myCleaner = new SaveAccumContextCleaner(this) + override def cleaner: Option[ContextCleaner] = Some(myCleaner) + } + assert(Accumulators.originals.isEmpty) + sc.parallelize(1 to 100).map { i => (i, i) }.reduceByKey { _ + _ }.count() + val internalAccums = InternalAccumulator.create() + // We ran 2 stages, so we should have 2 sets of internal accumulators, 1 for each stage + assert(Accumulators.originals.size === internalAccums.size * 2) + val accumsRegistered = sc.cleaner match { + case Some(cleaner: SaveAccumContextCleaner) => cleaner.accumsRegisteredForCleanup + case _ => Seq.empty[Long] + } + // Make sure the same set of accumulators is registered for cleanup + assert(accumsRegistered.size === internalAccums.size * 2) + assert(accumsRegistered.toSet === Accumulators.originals.keys.toSet) + } + + /** + * Return the accumulable info that matches the specified name. + */ + private def findTestAccum(accums: Iterable[AccumulableInfo]): AccumulableInfo = { + accums.find { a => a.name == Some(TEST_ACCUM) }.getOrElse { + fail(s"unable to find internal accumulator called $TEST_ACCUM") + } + } + + /** + * Test whether internal accumulators are merged properly if some tasks fail. + * TODO: make this actually retry the stage. + */ + private def testInternalAccumulatorsWithFailedTasks(failCondition: (Int => Boolean)): Unit = { + val listener = new SaveInfoListener + val numPartitions = 10 + val numFailedPartitions = (0 until numPartitions).count(failCondition) + // This says use 1 core and retry tasks up to 2 times + sc = new SparkContext("local[1, 2]", "test") + sc.addSparkListener(listener) + val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) => + val taskContext = TaskContext.get() + taskContext.taskMetrics().getAccum(TEST_ACCUM) += 1 + // Fail the first attempts of a subset of the tasks + if (failCondition(i) && taskContext.attemptNumber() == 0) { + throw new Exception("Failing a task intentionally.") + } + iter + } + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { _ => + val stageInfos = listener.getCompletedStageInfos + val taskInfos = listener.getCompletedTaskInfos + assert(stageInfos.size === 1) + assert(taskInfos.size === numPartitions + numFailedPartitions) + val stageAccum = findTestAccum(stageInfos.head.accumulables.values) + // If all partitions failed, then we would resubmit the whole stage again and create a + // fresh set of internal accumulators. Otherwise, these internal accumulators do count + // failed values, so we must include the failed values. 
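+      // For example, with numPartitions = 10 and failCondition = (_ % 2 == 0), the five even
+      // partitions each run twice (one failed attempt plus one success), so the stage-level
+      // value is 10 + 5 = 15; only when every partition fails is the whole stage resubmitted
+      // with a fresh set of internal accumulators, leaving just the 10 successful attempts.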
+ val expectedAccumValue = + if (numPartitions == numFailedPartitions) { + numPartitions + } else { + numPartitions + numFailedPartitions + } + assert(stageAccum.value.get.asInstanceOf[Long] === expectedAccumValue) + val taskAccumValues = taskInfos.flatMap { taskInfo => + if (!taskInfo.failed) { + // If a task succeeded, its update value should always be 1 + val taskAccum = findTestAccum(taskInfo.accumulables) + assert(taskAccum.update.isDefined) + assert(taskAccum.update.get.asInstanceOf[Long] === 1L) + assert(taskAccum.value.isDefined) + Some(taskAccum.value.get.asInstanceOf[Long]) + } else { + // If a task failed, we should not get its accumulator values + assert(taskInfo.accumulables.isEmpty) + None + } + } + assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) + } + rdd.count() + listener.maybeThrowException() + } + + /** + * A special [[ContextCleaner]] that saves the IDs of the accumulators registered for cleanup. + */ + private class SaveAccumContextCleaner(sc: SparkContext) extends ContextCleaner(sc) { + private val accumsRegistered = new ArrayBuffer[Long] + + override def registerAccumulatorForCleanup(a: Accumulable[_, _]): Unit = { + accumsRegistered += a.id + super.registerAccumulatorForCleanup(a) + } + + def accumsRegisteredForCleanup: Seq[Long] = accumsRegistered.toArray + } + +} diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 9be9db01c7de..d3359c7406e4 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -42,6 +42,8 @@ private[spark] abstract class SparkFunSuite extends FunSuite with Logging { test() } finally { logInfo(s"\n\n===== FINISHED $shortSuiteName: '$testName' =====\n") + // Avoid leaking map entries in tests that use accumulators without SparkContext + Accumulators.clear() } } diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala index e5ec2aa1be35..15be0b194ed8 100644 --- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala @@ -17,12 +17,542 @@ package org.apache.spark.executor -import org.apache.spark.SparkFunSuite +import org.scalatest.Assertions + +import org.apache.spark._ +import org.apache.spark.scheduler.AccumulableInfo +import org.apache.spark.storage.{BlockId, BlockStatus, StorageLevel, TestBlockId} + class TaskMetricsSuite extends SparkFunSuite { - test("[SPARK-5701] updateShuffleReadMetrics: ShuffleReadMetrics not added when no shuffle deps") { - val taskMetrics = new TaskMetrics() - taskMetrics.mergeShuffleReadMetrics() - assert(taskMetrics.shuffleReadMetrics.isEmpty) + import AccumulatorParam._ + import InternalAccumulator._ + import StorageLevel._ + import TaskMetricsSuite._ + + test("create") { + val internalAccums = InternalAccumulator.create() + val tm1 = new TaskMetrics + val tm2 = new TaskMetrics(internalAccums) + assert(tm1.accumulatorUpdates().size === internalAccums.size) + assert(tm1.shuffleReadMetrics.isEmpty) + assert(tm1.shuffleWriteMetrics.isEmpty) + assert(tm1.inputMetrics.isEmpty) + assert(tm1.outputMetrics.isEmpty) + assert(tm2.accumulatorUpdates().size === internalAccums.size) + assert(tm2.shuffleReadMetrics.isEmpty) + assert(tm2.shuffleWriteMetrics.isEmpty) + assert(tm2.inputMetrics.isEmpty) + assert(tm2.outputMetrics.isEmpty) + // TaskMetrics constructor expects 
minimal set of initial accumulators + intercept[IllegalArgumentException] { new TaskMetrics(Seq.empty[Accumulator[_]]) } + } + + test("create with unnamed accum") { + intercept[IllegalArgumentException] { + new TaskMetrics( + InternalAccumulator.create() ++ Seq( + new Accumulator(0, IntAccumulatorParam, None, internal = true))) + } + } + + test("create with duplicate name accum") { + intercept[IllegalArgumentException] { + new TaskMetrics( + InternalAccumulator.create() ++ Seq( + new Accumulator(0, IntAccumulatorParam, Some(RESULT_SIZE), internal = true))) + } + } + + test("create with external accum") { + intercept[IllegalArgumentException] { + new TaskMetrics( + InternalAccumulator.create() ++ Seq( + new Accumulator(0, IntAccumulatorParam, Some("x")))) + } + } + + test("create shuffle read metrics") { + import shuffleRead._ + val accums = InternalAccumulator.createShuffleReadAccums() + .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]] + accums(REMOTE_BLOCKS_FETCHED).setValueAny(1) + accums(LOCAL_BLOCKS_FETCHED).setValueAny(2) + accums(REMOTE_BYTES_READ).setValueAny(3L) + accums(LOCAL_BYTES_READ).setValueAny(4L) + accums(FETCH_WAIT_TIME).setValueAny(5L) + accums(RECORDS_READ).setValueAny(6L) + val sr = new ShuffleReadMetrics(accums) + assert(sr.remoteBlocksFetched === 1) + assert(sr.localBlocksFetched === 2) + assert(sr.remoteBytesRead === 3L) + assert(sr.localBytesRead === 4L) + assert(sr.fetchWaitTime === 5L) + assert(sr.recordsRead === 6L) + } + + test("create shuffle write metrics") { + import shuffleWrite._ + val accums = InternalAccumulator.createShuffleWriteAccums() + .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]] + accums(BYTES_WRITTEN).setValueAny(1L) + accums(RECORDS_WRITTEN).setValueAny(2L) + accums(WRITE_TIME).setValueAny(3L) + val sw = new ShuffleWriteMetrics(accums) + assert(sw.bytesWritten === 1L) + assert(sw.recordsWritten === 2L) + assert(sw.writeTime === 3L) + } + + test("create input metrics") { + import input._ + val accums = InternalAccumulator.createInputAccums() + .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]] + accums(BYTES_READ).setValueAny(1L) + accums(RECORDS_READ).setValueAny(2L) + accums(READ_METHOD).setValueAny(DataReadMethod.Hadoop.toString) + val im = new InputMetrics(accums) + assert(im.bytesRead === 1L) + assert(im.recordsRead === 2L) + assert(im.readMethod === DataReadMethod.Hadoop) + } + + test("create output metrics") { + import output._ + val accums = InternalAccumulator.createOutputAccums() + .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]] + accums(BYTES_WRITTEN).setValueAny(1L) + accums(RECORDS_WRITTEN).setValueAny(2L) + accums(WRITE_METHOD).setValueAny(DataWriteMethod.Hadoop.toString) + val om = new OutputMetrics(accums) + assert(om.bytesWritten === 1L) + assert(om.recordsWritten === 2L) + assert(om.writeMethod === DataWriteMethod.Hadoop) + } + + test("mutating values") { + val accums = InternalAccumulator.create() + val tm = new TaskMetrics(accums) + // initial values + assertValueEquals(tm, _.executorDeserializeTime, accums, EXECUTOR_DESERIALIZE_TIME, 0L) + assertValueEquals(tm, _.executorRunTime, accums, EXECUTOR_RUN_TIME, 0L) + assertValueEquals(tm, _.resultSize, accums, RESULT_SIZE, 0L) + assertValueEquals(tm, _.jvmGCTime, accums, JVM_GC_TIME, 0L) + assertValueEquals(tm, _.resultSerializationTime, accums, RESULT_SERIALIZATION_TIME, 0L) + assertValueEquals(tm, _.memoryBytesSpilled, accums, MEMORY_BYTES_SPILLED, 0L) + assertValueEquals(tm, _.diskBytesSpilled, accums, DISK_BYTES_SPILLED, 0L) + 
assertValueEquals(tm, _.peakExecutionMemory, accums, PEAK_EXECUTION_MEMORY, 0L) + assertValueEquals(tm, _.updatedBlockStatuses, accums, UPDATED_BLOCK_STATUSES, + Seq.empty[(BlockId, BlockStatus)]) + // set or increment values + tm.setExecutorDeserializeTime(100L) + tm.setExecutorDeserializeTime(1L) // overwrite + tm.setExecutorRunTime(200L) + tm.setExecutorRunTime(2L) + tm.setResultSize(300L) + tm.setResultSize(3L) + tm.setJvmGCTime(400L) + tm.setJvmGCTime(4L) + tm.setResultSerializationTime(500L) + tm.setResultSerializationTime(5L) + tm.incMemoryBytesSpilled(600L) + tm.incMemoryBytesSpilled(6L) // add + tm.incDiskBytesSpilled(700L) + tm.incDiskBytesSpilled(7L) + tm.incPeakExecutionMemory(800L) + tm.incPeakExecutionMemory(8L) + val block1 = (TestBlockId("a"), BlockStatus(MEMORY_ONLY, 1L, 2L)) + val block2 = (TestBlockId("b"), BlockStatus(MEMORY_ONLY, 3L, 4L)) + tm.incUpdatedBlockStatuses(Seq(block1)) + tm.incUpdatedBlockStatuses(Seq(block2)) + // assert new values exist + assertValueEquals(tm, _.executorDeserializeTime, accums, EXECUTOR_DESERIALIZE_TIME, 1L) + assertValueEquals(tm, _.executorRunTime, accums, EXECUTOR_RUN_TIME, 2L) + assertValueEquals(tm, _.resultSize, accums, RESULT_SIZE, 3L) + assertValueEquals(tm, _.jvmGCTime, accums, JVM_GC_TIME, 4L) + assertValueEquals(tm, _.resultSerializationTime, accums, RESULT_SERIALIZATION_TIME, 5L) + assertValueEquals(tm, _.memoryBytesSpilled, accums, MEMORY_BYTES_SPILLED, 606L) + assertValueEquals(tm, _.diskBytesSpilled, accums, DISK_BYTES_SPILLED, 707L) + assertValueEquals(tm, _.peakExecutionMemory, accums, PEAK_EXECUTION_MEMORY, 808L) + assertValueEquals(tm, _.updatedBlockStatuses, accums, UPDATED_BLOCK_STATUSES, + Seq(block1, block2)) + } + + test("mutating shuffle read metrics values") { + import shuffleRead._ + val accums = InternalAccumulator.create() + val tm = new TaskMetrics(accums) + def assertValEquals[T](tmValue: ShuffleReadMetrics => T, name: String, value: T): Unit = { + assertValueEquals(tm, tm => tmValue(tm.shuffleReadMetrics.get), accums, name, value) + } + // create shuffle read metrics + assert(tm.shuffleReadMetrics.isEmpty) + tm.registerTempShuffleReadMetrics() + tm.mergeShuffleReadMetrics() + assert(tm.shuffleReadMetrics.isDefined) + val sr = tm.shuffleReadMetrics.get + // initial values + assertValEquals(_.remoteBlocksFetched, REMOTE_BLOCKS_FETCHED, 0) + assertValEquals(_.localBlocksFetched, LOCAL_BLOCKS_FETCHED, 0) + assertValEquals(_.remoteBytesRead, REMOTE_BYTES_READ, 0L) + assertValEquals(_.localBytesRead, LOCAL_BYTES_READ, 0L) + assertValEquals(_.fetchWaitTime, FETCH_WAIT_TIME, 0L) + assertValEquals(_.recordsRead, RECORDS_READ, 0L) + // set and increment values + sr.setRemoteBlocksFetched(100) + sr.setRemoteBlocksFetched(10) + sr.incRemoteBlocksFetched(1) // 10 + 1 + sr.incRemoteBlocksFetched(1) // 10 + 1 + 1 + sr.setLocalBlocksFetched(200) + sr.setLocalBlocksFetched(20) + sr.incLocalBlocksFetched(2) + sr.incLocalBlocksFetched(2) + sr.setRemoteBytesRead(300L) + sr.setRemoteBytesRead(30L) + sr.incRemoteBytesRead(3L) + sr.incRemoteBytesRead(3L) + sr.setLocalBytesRead(400L) + sr.setLocalBytesRead(40L) + sr.incLocalBytesRead(4L) + sr.incLocalBytesRead(4L) + sr.setFetchWaitTime(500L) + sr.setFetchWaitTime(50L) + sr.incFetchWaitTime(5L) + sr.incFetchWaitTime(5L) + sr.setRecordsRead(600L) + sr.setRecordsRead(60L) + sr.incRecordsRead(6L) + sr.incRecordsRead(6L) + // assert new values exist + assertValEquals(_.remoteBlocksFetched, REMOTE_BLOCKS_FETCHED, 12) + assertValEquals(_.localBlocksFetched, LOCAL_BLOCKS_FETCHED, 24) + 
assertValEquals(_.remoteBytesRead, REMOTE_BYTES_READ, 36L) + assertValEquals(_.localBytesRead, LOCAL_BYTES_READ, 48L) + assertValEquals(_.fetchWaitTime, FETCH_WAIT_TIME, 60L) + assertValEquals(_.recordsRead, RECORDS_READ, 72L) + } + + test("mutating shuffle write metrics values") { + import shuffleWrite._ + val accums = InternalAccumulator.create() + val tm = new TaskMetrics(accums) + def assertValEquals[T](tmValue: ShuffleWriteMetrics => T, name: String, value: T): Unit = { + assertValueEquals(tm, tm => tmValue(tm.shuffleWriteMetrics.get), accums, name, value) + } + // create shuffle write metrics + assert(tm.shuffleWriteMetrics.isEmpty) + tm.registerShuffleWriteMetrics() + assert(tm.shuffleWriteMetrics.isDefined) + val sw = tm.shuffleWriteMetrics.get + // initial values + assertValEquals(_.bytesWritten, BYTES_WRITTEN, 0L) + assertValEquals(_.recordsWritten, RECORDS_WRITTEN, 0L) + assertValEquals(_.writeTime, WRITE_TIME, 0L) + // increment and decrement values + sw.incBytesWritten(100L) + sw.incBytesWritten(10L) // 100 + 10 + sw.decBytesWritten(1L) // 100 + 10 - 1 + sw.decBytesWritten(1L) // 100 + 10 - 1 - 1 + sw.incRecordsWritten(200L) + sw.incRecordsWritten(20L) + sw.decRecordsWritten(2L) + sw.decRecordsWritten(2L) + sw.incWriteTime(300L) + sw.incWriteTime(30L) + // assert new values exist + assertValEquals(_.bytesWritten, BYTES_WRITTEN, 108L) + assertValEquals(_.recordsWritten, RECORDS_WRITTEN, 216L) + assertValEquals(_.writeTime, WRITE_TIME, 330L) + } + + test("mutating input metrics values") { + import input._ + val accums = InternalAccumulator.create() + val tm = new TaskMetrics(accums) + def assertValEquals(tmValue: InputMetrics => Any, name: String, value: Any): Unit = { + assertValueEquals(tm, tm => tmValue(tm.inputMetrics.get), accums, name, value, + (x: Any, y: Any) => assert(x.toString === y.toString)) + } + // create input metrics + assert(tm.inputMetrics.isEmpty) + tm.registerInputMetrics(DataReadMethod.Memory) + assert(tm.inputMetrics.isDefined) + val in = tm.inputMetrics.get + // initial values + assertValEquals(_.bytesRead, BYTES_READ, 0L) + assertValEquals(_.recordsRead, RECORDS_READ, 0L) + assertValEquals(_.readMethod, READ_METHOD, DataReadMethod.Memory) + // set and increment values + in.setBytesRead(1L) + in.setBytesRead(2L) + in.incRecordsRead(1L) + in.incRecordsRead(2L) + in.setReadMethod(DataReadMethod.Disk) + // assert new values exist + assertValEquals(_.bytesRead, BYTES_READ, 2L) + assertValEquals(_.recordsRead, RECORDS_READ, 3L) + assertValEquals(_.readMethod, READ_METHOD, DataReadMethod.Disk) + } + + test("mutating output metrics values") { + import output._ + val accums = InternalAccumulator.create() + val tm = new TaskMetrics(accums) + def assertValEquals(tmValue: OutputMetrics => Any, name: String, value: Any): Unit = { + assertValueEquals(tm, tm => tmValue(tm.outputMetrics.get), accums, name, value, + (x: Any, y: Any) => assert(x.toString === y.toString)) + } + // create input metrics + assert(tm.outputMetrics.isEmpty) + tm.registerOutputMetrics(DataWriteMethod.Hadoop) + assert(tm.outputMetrics.isDefined) + val out = tm.outputMetrics.get + // initial values + assertValEquals(_.bytesWritten, BYTES_WRITTEN, 0L) + assertValEquals(_.recordsWritten, RECORDS_WRITTEN, 0L) + assertValEquals(_.writeMethod, WRITE_METHOD, DataWriteMethod.Hadoop) + // set values + out.setBytesWritten(1L) + out.setBytesWritten(2L) + out.setRecordsWritten(3L) + out.setRecordsWritten(4L) + out.setWriteMethod(DataWriteMethod.Hadoop) + // assert new values exist + 
assertValEquals(_.bytesWritten, BYTES_WRITTEN, 2L) + assertValEquals(_.recordsWritten, RECORDS_WRITTEN, 4L) + // Note: this doesn't actually test anything, but there's only one DataWriteMethod + // so we can't set it to anything else + assertValEquals(_.writeMethod, WRITE_METHOD, DataWriteMethod.Hadoop) + } + + test("merging multiple shuffle read metrics") { + val tm = new TaskMetrics + assert(tm.shuffleReadMetrics.isEmpty) + val sr1 = tm.registerTempShuffleReadMetrics() + val sr2 = tm.registerTempShuffleReadMetrics() + val sr3 = tm.registerTempShuffleReadMetrics() + assert(tm.shuffleReadMetrics.isEmpty) + sr1.setRecordsRead(10L) + sr2.setRecordsRead(10L) + sr1.setFetchWaitTime(1L) + sr2.setFetchWaitTime(2L) + sr3.setFetchWaitTime(3L) + tm.mergeShuffleReadMetrics() + assert(tm.shuffleReadMetrics.isDefined) + val sr = tm.shuffleReadMetrics.get + assert(sr.remoteBlocksFetched === 0L) + assert(sr.recordsRead === 20L) + assert(sr.fetchWaitTime === 6L) + + // SPARK-5701: calling merge without any shuffle deps does nothing + val tm2 = new TaskMetrics + tm2.mergeShuffleReadMetrics() + assert(tm2.shuffleReadMetrics.isEmpty) + } + + test("register multiple shuffle write metrics") { + val tm = new TaskMetrics + val sw1 = tm.registerShuffleWriteMetrics() + val sw2 = tm.registerShuffleWriteMetrics() + assert(sw1 === sw2) + assert(tm.shuffleWriteMetrics === Some(sw1)) + } + + test("register multiple input metrics") { + val tm = new TaskMetrics + val im1 = tm.registerInputMetrics(DataReadMethod.Memory) + val im2 = tm.registerInputMetrics(DataReadMethod.Memory) + // input metrics with a different read method than the one already registered are ignored + val im3 = tm.registerInputMetrics(DataReadMethod.Hadoop) + assert(im1 === im2) + assert(im1 !== im3) + assert(tm.inputMetrics === Some(im1)) + im2.setBytesRead(50L) + im3.setBytesRead(100L) + assert(tm.inputMetrics.get.bytesRead === 50L) + } + + test("register multiple output metrics") { + val tm = new TaskMetrics + val om1 = tm.registerOutputMetrics(DataWriteMethod.Hadoop) + val om2 = tm.registerOutputMetrics(DataWriteMethod.Hadoop) + assert(om1 === om2) + assert(tm.outputMetrics === Some(om1)) + } + + test("additional accumulables") { + val internalAccums = InternalAccumulator.create() + val tm = new TaskMetrics(internalAccums) + assert(tm.accumulatorUpdates().size === internalAccums.size) + val acc1 = new Accumulator(0, IntAccumulatorParam, Some("a")) + val acc2 = new Accumulator(0, IntAccumulatorParam, Some("b")) + val acc3 = new Accumulator(0, IntAccumulatorParam, Some("c")) + val acc4 = new Accumulator(0, IntAccumulatorParam, Some("d"), + internal = true, countFailedValues = true) + tm.registerAccumulator(acc1) + tm.registerAccumulator(acc2) + tm.registerAccumulator(acc3) + tm.registerAccumulator(acc4) + acc1 += 1 + acc2 += 2 + val newUpdates = tm.accumulatorUpdates().map { a => (a.id, a) }.toMap + assert(newUpdates.contains(acc1.id)) + assert(newUpdates.contains(acc2.id)) + assert(newUpdates.contains(acc3.id)) + assert(newUpdates.contains(acc4.id)) + assert(newUpdates(acc1.id).name === Some("a")) + assert(newUpdates(acc2.id).name === Some("b")) + assert(newUpdates(acc3.id).name === Some("c")) + assert(newUpdates(acc4.id).name === Some("d")) + assert(newUpdates(acc1.id).update === Some(1)) + assert(newUpdates(acc2.id).update === Some(2)) + assert(newUpdates(acc3.id).update === Some(0)) + assert(newUpdates(acc4.id).update === Some(0)) + assert(!newUpdates(acc3.id).internal) + assert(!newUpdates(acc3.id).countFailedValues) + 
assert(newUpdates(acc4.id).internal) + assert(newUpdates(acc4.id).countFailedValues) + assert(newUpdates.values.map(_.update).forall(_.isDefined)) + assert(newUpdates.values.map(_.value).forall(_.isEmpty)) + assert(newUpdates.size === internalAccums.size + 4) + } + + test("existing values in shuffle read accums") { + // set shuffle read accum before passing it into TaskMetrics + val accums = InternalAccumulator.create() + val srAccum = accums.find(_.name === Some(shuffleRead.FETCH_WAIT_TIME)) + assert(srAccum.isDefined) + srAccum.get.asInstanceOf[Accumulator[Long]] += 10L + val tm = new TaskMetrics(accums) + assert(tm.shuffleReadMetrics.isDefined) + assert(tm.shuffleWriteMetrics.isEmpty) + assert(tm.inputMetrics.isEmpty) + assert(tm.outputMetrics.isEmpty) + } + + test("existing values in shuffle write accums") { + // set shuffle write accum before passing it into TaskMetrics + val accums = InternalAccumulator.create() + val swAccum = accums.find(_.name === Some(shuffleWrite.RECORDS_WRITTEN)) + assert(swAccum.isDefined) + swAccum.get.asInstanceOf[Accumulator[Long]] += 10L + val tm = new TaskMetrics(accums) + assert(tm.shuffleReadMetrics.isEmpty) + assert(tm.shuffleWriteMetrics.isDefined) + assert(tm.inputMetrics.isEmpty) + assert(tm.outputMetrics.isEmpty) + } + + test("existing values in input accums") { + // set input accum before passing it into TaskMetrics + val accums = InternalAccumulator.create() + val inAccum = accums.find(_.name === Some(input.RECORDS_READ)) + assert(inAccum.isDefined) + inAccum.get.asInstanceOf[Accumulator[Long]] += 10L + val tm = new TaskMetrics(accums) + assert(tm.shuffleReadMetrics.isEmpty) + assert(tm.shuffleWriteMetrics.isEmpty) + assert(tm.inputMetrics.isDefined) + assert(tm.outputMetrics.isEmpty) } + + test("existing values in output accums") { + // set output accum before passing it into TaskMetrics + val accums = InternalAccumulator.create() + val outAccum = accums.find(_.name === Some(output.RECORDS_WRITTEN)) + assert(outAccum.isDefined) + outAccum.get.asInstanceOf[Accumulator[Long]] += 10L + val tm4 = new TaskMetrics(accums) + assert(tm4.shuffleReadMetrics.isEmpty) + assert(tm4.shuffleWriteMetrics.isEmpty) + assert(tm4.inputMetrics.isEmpty) + assert(tm4.outputMetrics.isDefined) + } + + test("from accumulator updates") { + val accumUpdates1 = InternalAccumulator.create().map { a => + AccumulableInfo(a.id, a.name, Some(3L), None, a.isInternal, a.countFailedValues) + } + val metrics1 = TaskMetrics.fromAccumulatorUpdates(accumUpdates1) + assertUpdatesEquals(metrics1.accumulatorUpdates(), accumUpdates1) + // Test this with additional accumulators. Only the ones registered with `Accumulators` + // will show up in the reconstructed TaskMetrics. In practice, all accumulators created + // on the driver, internal or not, should be registered with `Accumulators` at some point. + // Here we show that reconstruction will succeed even if there are unregistered accumulators. 
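+    // Concretely, the accumulators named "a" through "d" below are registered and therefore
+    // reappear in the reconstructed metrics, while the unregistered "e" and "f" are dropped
+    // without making TaskMetrics.fromAccumulatorUpdates fail.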
+ val param = IntAccumulatorParam + val registeredAccums = Seq( + new Accumulator(0, param, Some("a"), internal = true, countFailedValues = true), + new Accumulator(0, param, Some("b"), internal = true, countFailedValues = false), + new Accumulator(0, param, Some("c"), internal = false, countFailedValues = true), + new Accumulator(0, param, Some("d"), internal = false, countFailedValues = false)) + val unregisteredAccums = Seq( + new Accumulator(0, param, Some("e"), internal = true, countFailedValues = true), + new Accumulator(0, param, Some("f"), internal = true, countFailedValues = false)) + registeredAccums.foreach(Accumulators.register) + registeredAccums.foreach { a => assert(Accumulators.originals.contains(a.id)) } + unregisteredAccums.foreach { a => assert(!Accumulators.originals.contains(a.id)) } + // set some values in these accums + registeredAccums.zipWithIndex.foreach { case (a, i) => a.setValue(i) } + unregisteredAccums.zipWithIndex.foreach { case (a, i) => a.setValue(i) } + val registeredAccumInfos = registeredAccums.map(makeInfo) + val unregisteredAccumInfos = unregisteredAccums.map(makeInfo) + val accumUpdates2 = accumUpdates1 ++ registeredAccumInfos ++ unregisteredAccumInfos + val metrics2 = TaskMetrics.fromAccumulatorUpdates(accumUpdates2) + // accumulators that were not registered with `Accumulators` will not show up + assertUpdatesEquals(metrics2.accumulatorUpdates(), accumUpdates1 ++ registeredAccumInfos) + } +} + + +private[spark] object TaskMetricsSuite extends Assertions { + + /** + * Assert that the following three things are equal to `value`: + * (1) TaskMetrics value + * (2) TaskMetrics accumulator update value + * (3) Original accumulator value + */ + def assertValueEquals( + tm: TaskMetrics, + tmValue: TaskMetrics => Any, + accums: Seq[Accumulator[_]], + metricName: String, + value: Any, + assertEquals: (Any, Any) => Unit = (x: Any, y: Any) => assert(x === y)): Unit = { + assertEquals(tmValue(tm), value) + val accum = accums.find(_.name == Some(metricName)) + assert(accum.isDefined) + assertEquals(accum.get.value, value) + val accumUpdate = tm.accumulatorUpdates().find(_.name == Some(metricName)) + assert(accumUpdate.isDefined) + assert(accumUpdate.get.value === None) + assertEquals(accumUpdate.get.update, Some(value)) + } + + /** + * Assert that two lists of accumulator updates are equal. + * Note: this does NOT check accumulator ID equality. + */ + def assertUpdatesEquals( + updates1: Seq[AccumulableInfo], + updates2: Seq[AccumulableInfo]): Unit = { + assert(updates1.size === updates2.size) + updates1.zip(updates2).foreach { case (info1, info2) => + // do not assert ID equals here + assert(info1.name === info2.name) + assert(info1.update === info2.update) + assert(info1.value === info2.value) + assert(info1.internal === info2.internal) + assert(info1.countFailedValues === info2.countFailedValues) + } + } + + /** + * Make an [[AccumulableInfo]] out of an [[Accumulable]] with the intent to use the + * info as an accumulator update. 
+ */ + def makeInfo(a: Accumulable[_, _]): AccumulableInfo = { + new AccumulableInfo(a.id, a.name, Some(a.value), None, a.isInternal, a.countFailedValues) + } + } diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala index 0e60cc8e7787..2b5e4b80e96a 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala @@ -31,7 +31,6 @@ object MemoryTestingUtils { taskAttemptId = 0, attemptNumber = 0, taskMemoryManager = taskMemoryManager, - metricsSystem = env.metricsSystem, - internalAccumulators = Seq.empty) + metricsSystem = env.metricsSystem) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 370a284d2950..d9c71ec2eae7 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -23,7 +23,6 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import scala.language.reflectiveCalls import scala.util.control.NonFatal -import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ @@ -96,8 +95,7 @@ class MyRDD( class DAGSchedulerSuiteDummyException extends Exception -class DAGSchedulerSuite - extends SparkFunSuite with BeforeAndAfter with LocalSparkContext with Timeouts { +class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeouts { val conf = new SparkConf /** Set of TaskSets the DAGScheduler has requested executed. */ @@ -111,8 +109,10 @@ class DAGSchedulerSuite override def schedulingMode: SchedulingMode = SchedulingMode.NONE override def start() = {} override def stop() = {} - override def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)], - blockManagerId: BlockManagerId): Boolean = true + override def executorHeartbeatReceived( + execId: String, + accumUpdates: Array[(Long, Seq[AccumulableInfo])], + blockManagerId: BlockManagerId): Boolean = true override def submitTasks(taskSet: TaskSet) = { // normally done by TaskSetManager taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch) @@ -189,7 +189,8 @@ class DAGSchedulerSuite override def jobFailed(exception: Exception): Unit = { failure = exception } } - before { + override def beforeEach(): Unit = { + super.beforeEach() sc = new SparkContext("local", "DAGSchedulerSuite") sparkListener.submittedStageInfos.clear() sparkListener.successfulStages.clear() @@ -202,17 +203,21 @@ class DAGSchedulerSuite results.clear() mapOutputTracker = new MapOutputTrackerMaster(conf) scheduler = new DAGScheduler( - sc, - taskScheduler, - sc.listenerBus, - mapOutputTracker, - blockManagerMaster, - sc.env) + sc, + taskScheduler, + sc.listenerBus, + mapOutputTracker, + blockManagerMaster, + sc.env) dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler) } - after { - scheduler.stop() + override def afterEach(): Unit = { + try { + scheduler.stop() + } finally { + super.afterEach() + } } override def afterAll() { @@ -242,26 +247,31 @@ class DAGSchedulerSuite * directly through CompletionEvents. 
*/ private val jobComputeFunc = (context: TaskContext, it: Iterator[(_)]) => - it.next.asInstanceOf[Tuple2[_, _]]._1 + it.next.asInstanceOf[Tuple2[_, _]]._1 /** Send the given CompletionEvent messages for the tasks in the TaskSet. */ private def complete(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) { assert(taskSet.tasks.size >= results.size) for ((result, i) <- results.zipWithIndex) { if (i < taskSet.tasks.size) { - runEvent(CompletionEvent( - taskSet.tasks(i), result._1, result._2, null, createFakeTaskInfo(), null)) + runEvent(makeCompletionEvent(taskSet.tasks(i), result._1, result._2)) } } } - private def completeWithAccumulator(accumId: Long, taskSet: TaskSet, - results: Seq[(TaskEndReason, Any)]) { + private def completeWithAccumulator( + accumId: Long, + taskSet: TaskSet, + results: Seq[(TaskEndReason, Any)]) { assert(taskSet.tasks.size >= results.size) for ((result, i) <- results.zipWithIndex) { if (i < taskSet.tasks.size) { - runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, - Map[Long, Any]((accumId, 1)), createFakeTaskInfo(), null)) + runEvent(makeCompletionEvent( + taskSet.tasks(i), + result._1, + result._2, + Seq(new AccumulableInfo( + accumId, Some(""), Some(1), None, internal = false, countFailedValues = false)))) } } } @@ -338,9 +348,12 @@ class DAGSchedulerSuite } test("equals and hashCode AccumulableInfo") { - val accInfo1 = new AccumulableInfo(1, " Accumulable " + 1, Some("delta" + 1), "val" + 1, true) - val accInfo2 = new AccumulableInfo(1, " Accumulable " + 1, Some("delta" + 1), "val" + 1, false) - val accInfo3 = new AccumulableInfo(1, " Accumulable " + 1, Some("delta" + 1), "val" + 1, false) + val accInfo1 = new AccumulableInfo( + 1, Some("a1"), Some("delta1"), Some("val1"), internal = true, countFailedValues = false) + val accInfo2 = new AccumulableInfo( + 1, Some("a1"), Some("delta1"), Some("val1"), internal = false, countFailedValues = false) + val accInfo3 = new AccumulableInfo( + 1, Some("a1"), Some("delta1"), Some("val1"), internal = false, countFailedValues = false) assert(accInfo1 !== accInfo2) assert(accInfo2 === accInfo3) assert(accInfo2.hashCode() === accInfo3.hashCode()) @@ -464,7 +477,7 @@ class DAGSchedulerSuite override def defaultParallelism(): Int = 2 override def executorHeartbeatReceived( execId: String, - taskMetrics: Array[(Long, TaskMetrics)], + accumUpdates: Array[(Long, Seq[AccumulableInfo])], blockManagerId: BlockManagerId): Boolean = true override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} override def applicationAttemptId(): Option[String] = None @@ -499,8 +512,8 @@ class DAGSchedulerSuite val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0)) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostB", 1)))) + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)))) assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) complete(taskSets(1), Seq((Success, 42))) @@ -515,12 +528,12 @@ class DAGSchedulerSuite val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0, 1)) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), - (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) + (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", 
reduceRdd.partitions.length)))) // the 2nd ResultTask failed complete(taskSets(1), Seq( - (Success, 42), - (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null))) + (Success, 42), + (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null))) // this will get called // blockManagerMaster.removeExecutor("exec-hostA") // ask the scheduler to try it again @@ -829,23 +842,17 @@ class DAGSchedulerSuite HashSet("hostA", "hostB")) // The first result task fails, with a fetch failure for the output from the first mapper. - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSets(1).tasks(0), FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), - null, - Map[Long, Any](), - createFakeTaskInfo(), null)) sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(sparkListener.failedStages.contains(1)) // The second ResultTask fails, with a fetch failure for the output from the second mapper. - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSets(1).tasks(0), FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1, "ignored"), - null, - Map[Long, Any](), - createFakeTaskInfo(), null)) // The SparkListener should not receive redundant failure events. sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) @@ -882,12 +889,9 @@ class DAGSchedulerSuite HashSet("hostA", "hostB")) // The first result task fails, with a fetch failure for the output from the first mapper. - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSets(1).tasks(0), FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), - null, - Map[Long, Any](), - createFakeTaskInfo(), null)) sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(sparkListener.failedStages.contains(1)) @@ -900,12 +904,9 @@ class DAGSchedulerSuite assert(countSubmittedMapStageAttempts() === 2) // The second ResultTask fails, with a fetch failure for the output from the second mapper. - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSets(1).tasks(1), FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"), - null, - Map[Long, Any](), - createFakeTaskInfo(), null)) // Another ResubmitFailedStages event should not result in another attempt for the map @@ -920,11 +921,11 @@ class DAGSchedulerSuite } /** - * This tests the case where a late FetchFailed comes in after the map stage has finished getting - * retried and a new reduce stage starts running. - */ + * This tests the case where a late FetchFailed comes in after the map stage has finished getting + * retried and a new reduce stage starts running. + */ test("extremely late fetch failures don't cause multiple concurrent attempts for " + - "the same stage") { + "the same stage") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val shuffleId = shuffleDep.shuffleId @@ -952,12 +953,9 @@ class DAGSchedulerSuite assert(countSubmittedReduceStageAttempts() === 1) // The first result task fails, with a fetch failure for the output from the first mapper. - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSets(1).tasks(0), FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), - null, - Map[Long, Any](), - createFakeTaskInfo(), null)) // Trigger resubmission of the failed map stage and finish the re-started map task. @@ -971,12 +969,9 @@ class DAGSchedulerSuite assert(countSubmittedReduceStageAttempts() === 2) // A late FetchFailed arrives from the second task in the original reduce stage. 
- runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSets(1).tasks(1), FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"), - null, - Map[Long, Any](), - createFakeTaskInfo(), null)) // Running ResubmitFailedStages shouldn't result in any more attempts for the map stage, because @@ -1007,48 +1002,36 @@ class DAGSchedulerSuite assert(shuffleStage.numAvailableOutputs === 0) // should be ignored for being too old - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSet.tasks(0), Success, - makeMapStatus("hostA", reduceRdd.partitions.size), - null, - createFakeTaskInfo(), - null)) + makeMapStatus("hostA", reduceRdd.partitions.size))) assert(shuffleStage.numAvailableOutputs === 0) // should work because it's a non-failed host (so the available map outputs will increase) - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSet.tasks(0), Success, - makeMapStatus("hostB", reduceRdd.partitions.size), - null, - createFakeTaskInfo(), - null)) + makeMapStatus("hostB", reduceRdd.partitions.size))) assert(shuffleStage.numAvailableOutputs === 1) // should be ignored for being too old - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSet.tasks(0), Success, - makeMapStatus("hostA", reduceRdd.partitions.size), - null, - createFakeTaskInfo(), - null)) + makeMapStatus("hostA", reduceRdd.partitions.size))) assert(shuffleStage.numAvailableOutputs === 1) // should work because it's a new epoch, which will increase the number of available map // outputs, and also finish the stage taskSet.tasks(1).epoch = newEpoch - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSet.tasks(1), Success, - makeMapStatus("hostA", reduceRdd.partitions.size), - null, - createFakeTaskInfo(), - null)) + makeMapStatus("hostA", reduceRdd.partitions.size))) assert(shuffleStage.numAvailableOutputs === 2) assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) + HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) // finish the next stage normally, which completes the job complete(taskSets(1), Seq((Success, 42), (Success, 43))) @@ -1140,12 +1123,9 @@ class DAGSchedulerSuite // then one executor dies, and a task fails in stage 1 runEvent(ExecutorLost("exec-hostA")) - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSets(1).tasks(0), FetchFailed(null, firstShuffleId, 2, 0, "Fetch failed"), - null, - null, - createFakeTaskInfo(), null)) // so we resubmit stage 0, which completes happily @@ -1155,13 +1135,10 @@ class DAGSchedulerSuite assert(stage0Resubmit.stageAttemptId === 1) val task = stage0Resubmit.tasks(0) assert(task.partitionId === 2) - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( task, Success, - makeMapStatus("hostC", shuffleMapRdd.partitions.length), - null, - createFakeTaskInfo(), - null)) + makeMapStatus("hostC", shuffleMapRdd.partitions.length))) // now here is where things get tricky : we will now have a task set representing // the second attempt for stage 1, but we *also* have some tasks for the first attempt for @@ -1174,28 +1151,19 @@ class DAGSchedulerSuite // we'll have some tasks finish from the first attempt, and some finish from the second attempt, // so that we actually have all stage outputs, though no attempt has completed all its // tasks - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSets(3).tasks(0), Success, - makeMapStatus("hostC", reduceRdd.partitions.length), - null, - 
createFakeTaskInfo(), - null)) - runEvent(CompletionEvent( + makeMapStatus("hostC", reduceRdd.partitions.length))) + runEvent(makeCompletionEvent( taskSets(3).tasks(1), Success, - makeMapStatus("hostC", reduceRdd.partitions.length), - null, - createFakeTaskInfo(), - null)) + makeMapStatus("hostC", reduceRdd.partitions.length))) // late task finish from the first attempt - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSets(1).tasks(2), Success, - makeMapStatus("hostB", reduceRdd.partitions.length), - null, - createFakeTaskInfo(), - null)) + makeMapStatus("hostB", reduceRdd.partitions.length))) // What should happen now is that we submit stage 2. However, we might not see an error // b/c of DAGScheduler's error handling (it tends to swallow errors and just log them). But @@ -1242,21 +1210,21 @@ class DAGSchedulerSuite submit(reduceRdd, Array(0)) // complete some of the tasks from the first stage, on one host - runEvent(CompletionEvent( - taskSets(0).tasks(0), Success, - makeMapStatus("hostA", reduceRdd.partitions.length), null, createFakeTaskInfo(), null)) - runEvent(CompletionEvent( - taskSets(0).tasks(1), Success, - makeMapStatus("hostA", reduceRdd.partitions.length), null, createFakeTaskInfo(), null)) + runEvent(makeCompletionEvent( + taskSets(0).tasks(0), + Success, + makeMapStatus("hostA", reduceRdd.partitions.length))) + runEvent(makeCompletionEvent( + taskSets(0).tasks(1), + Success, + makeMapStatus("hostA", reduceRdd.partitions.length))) // now that host goes down runEvent(ExecutorLost("exec-hostA")) // so we resubmit those tasks - runEvent(CompletionEvent( - taskSets(0).tasks(0), Resubmitted, null, null, createFakeTaskInfo(), null)) - runEvent(CompletionEvent( - taskSets(0).tasks(1), Resubmitted, null, null, createFakeTaskInfo(), null)) + runEvent(makeCompletionEvent(taskSets(0).tasks(0), Resubmitted, null)) + runEvent(makeCompletionEvent(taskSets(0).tasks(1), Resubmitted, null)) // now complete everything on a different host complete(taskSets(0), Seq( @@ -1449,12 +1417,12 @@ class DAGSchedulerSuite // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks // rather than marking it is as failed and waiting. 
complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostB", 1)))) + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)))) // have hostC complete the resubmitted task complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) complete(taskSets(2), Seq((Success, 42))) assert(results === Map(0 -> 42)) assertDataStructuresEmpty() @@ -1469,15 +1437,15 @@ class DAGSchedulerSuite submit(finalRdd, Array(0)) // have the first stage complete normally complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 2)), - (Success, makeMapStatus("hostB", 2)))) + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)))) // have the second stage complete normally complete(taskSets(1), Seq( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostC", 1)))) + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostC", 1)))) // fail the third stage because hostA went down complete(taskSets(2), Seq( - (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null))) + (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null))) // TODO assert this: // blockManagerMaster.removeExecutor("exec-hostA") // have DAGScheduler try again @@ -1500,15 +1468,15 @@ class DAGSchedulerSuite cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC")) // complete stage 0 complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 2)), - (Success, makeMapStatus("hostB", 2)))) + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)))) // complete stage 1 complete(taskSets(1), Seq( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostB", 1)))) + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)))) // pretend stage 2 failed because hostA went down complete(taskSets(2), Seq( - (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null))) + (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null))) // TODO assert this: // blockManagerMaster.removeExecutor("exec-hostA") // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun. 
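// (Sketch of the two completion-event shapes used in the hunks around this point, mirroring
// calls that appear verbatim in this suite: the old form passed accumulator values and metrics
// positionally,
//   CompletionEvent(task, Success, result, Map[Long, Any](), createFakeTaskInfo(), null)
// while the new makeCompletionEvent helper, defined near the end of the suite, builds the
// updates itself; for a Success reason it emits one entry per initial accumulator,
//   new AccumulableInfo(a.id, a.name, Some(a.zero), None, a.isInternal, a.countFailedValues)
// so call sites only supply the task, the reason, the result, and any extra updates.)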
@@ -1606,6 +1574,28 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } + test("accumulators are updated on exception failures") { + val acc1 = sc.accumulator(0L, "ingenieur") + val acc2 = sc.accumulator(0L, "boulanger") + val acc3 = sc.accumulator(0L, "agriculteur") + assert(Accumulators.get(acc1.id).isDefined) + assert(Accumulators.get(acc2.id).isDefined) + assert(Accumulators.get(acc3.id).isDefined) + val accInfo1 = new AccumulableInfo( + acc1.id, acc1.name, Some(15L), None, internal = false, countFailedValues = false) + val accInfo2 = new AccumulableInfo( + acc2.id, acc2.name, Some(13L), None, internal = false, countFailedValues = false) + val accInfo3 = new AccumulableInfo( + acc3.id, acc3.name, Some(18L), None, internal = false, countFailedValues = false) + val accumUpdates = Seq(accInfo1, accInfo2, accInfo3) + val exceptionFailure = new ExceptionFailure(new SparkException("fondue?"), accumUpdates) + submit(new MyRDD(sc, 1, Nil), Array(0)) + runEvent(makeCompletionEvent(taskSets.head.tasks.head, exceptionFailure, "result")) + assert(Accumulators.get(acc1.id).get.value === 15L) + assert(Accumulators.get(acc2.id).get.value === 13L) + assert(Accumulators.get(acc3.id).get.value === 18L) + } + test("reduce tasks should be placed locally with map output") { // Create an shuffleMapRdd with 1 partition val shuffleMapRdd = new MyRDD(sc, 1, Nil) @@ -1614,9 +1604,9 @@ class DAGSchedulerSuite val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0)) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 1)))) + (Success, makeMapStatus("hostA", 1)))) assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostA"))) + HashSet(makeBlockManagerId("hostA"))) // Reducer should run on the same host that map task ran val reduceTaskSet = taskSets(1) @@ -1884,8 +1874,7 @@ class DAGSchedulerSuite submitMapStage(shuffleDep) val oldTaskSet = taskSets(0) - runEvent(CompletionEvent(oldTaskSet.tasks(0), Success, makeMapStatus("hostA", 2), - null, createFakeTaskInfo(), null)) + runEvent(makeCompletionEvent(oldTaskSet.tasks(0), Success, makeMapStatus("hostA", 2))) assert(results.size === 0) // Map stage job should not be complete yet // Pretend host A was lost @@ -1895,23 +1884,19 @@ class DAGSchedulerSuite assert(newEpoch > oldEpoch) // Suppose we also get a completed event from task 1 on the same host; this should be ignored - runEvent(CompletionEvent(oldTaskSet.tasks(1), Success, makeMapStatus("hostA", 2), - null, createFakeTaskInfo(), null)) + runEvent(makeCompletionEvent(oldTaskSet.tasks(1), Success, makeMapStatus("hostA", 2))) assert(results.size === 0) // Map stage job should not be complete yet // A completion from another task should work because it's a non-failed host - runEvent(CompletionEvent(oldTaskSet.tasks(2), Success, makeMapStatus("hostB", 2), - null, createFakeTaskInfo(), null)) + runEvent(makeCompletionEvent(oldTaskSet.tasks(2), Success, makeMapStatus("hostB", 2))) assert(results.size === 0) // Map stage job should not be complete yet // Now complete tasks in the second task set val newTaskSet = taskSets(1) assert(newTaskSet.tasks.size === 2) // Both tasks 0 and 1 were on on hostA - runEvent(CompletionEvent(newTaskSet.tasks(0), Success, makeMapStatus("hostB", 2), - null, createFakeTaskInfo(), null)) + runEvent(makeCompletionEvent(newTaskSet.tasks(0), Success, makeMapStatus("hostB", 2))) assert(results.size === 0) // Map stage job should not be complete yet - 
runEvent(CompletionEvent(newTaskSet.tasks(1), Success, makeMapStatus("hostB", 2), - null, createFakeTaskInfo(), null)) + runEvent(makeCompletionEvent(newTaskSet.tasks(1), Success, makeMapStatus("hostB", 2))) assert(results.size === 1) // Map stage job should now finally be complete assertDataStructuresEmpty() @@ -1962,5 +1947,21 @@ class DAGSchedulerSuite info } -} + private def makeCompletionEvent( + task: Task[_], + reason: TaskEndReason, + result: Any, + extraAccumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo], + taskInfo: TaskInfo = createFakeTaskInfo()): CompletionEvent = { + val accumUpdates = reason match { + case Success => + task.initialAccumulators.map { a => + new AccumulableInfo(a.id, a.name, Some(a.zero), None, a.isInternal, a.countFailedValues) + } + case ef: ExceptionFailure => ef.accumUpdates + case _ => Seq.empty[AccumulableInfo] + } + CompletionEvent(task, reason, result, accumUpdates ++ extraAccumUpdates, taskInfo) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 761e82e6cf1c..35215c15ea80 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -26,7 +26,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec -import org.apache.spark.util.{JsonProtocol, Utils} +import org.apache.spark.util.{JsonProtocol, JsonProtocolSuite, Utils} /** * Test whether ReplayListenerBus replays events from logs correctly. @@ -131,7 +131,11 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(sc.eventLogger.isDefined) val originalEvents = sc.eventLogger.get.loggedEvents val replayedEvents = eventMonster.loggedEvents - originalEvents.zip(replayedEvents).foreach { case (e1, e2) => assert(e1 === e2) } + originalEvents.zip(replayedEvents).foreach { case (e1, e2) => + // Don't compare the JSON here because accumulators in StageInfo may be out of order + JsonProtocolSuite.assertEquals( + JsonProtocol.sparkEventFromJson(e1), JsonProtocol.sparkEventFromJson(e2)) + } } /** diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index e5ec44a9f3b6..b3bb86db10a3 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -22,6 +22,8 @@ import org.mockito.Mockito._ import org.scalatest.BeforeAndAfter import org.apache.spark._ +import org.apache.spark.executor.TaskMetricsSuite +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.JvmSource import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD @@ -57,8 +59,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark val closureSerializer = SparkEnv.get.closureSerializer.newInstance() val func = (c: TaskContext, i: Iterator[String]) => i.next() val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func)))) - val task = new ResultTask[String, String]( - 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, Seq.empty) + val task = new ResultTask[String, String](0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0) intercept[RuntimeException] 
{ task.run(0, 0, null) } @@ -97,6 +98,57 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark }.collect() assert(attemptIdsWithFailedTask.toSet === Set(0, 1)) } + + test("accumulators are updated on exception failures") { + // This means use 1 core and 4 max task failures + sc = new SparkContext("local[1,4]", "test") + val param = AccumulatorParam.LongAccumulatorParam + // Create 2 accumulators, one that counts failed values and another that doesn't + val acc1 = new Accumulator(0L, param, Some("x"), internal = false, countFailedValues = true) + val acc2 = new Accumulator(0L, param, Some("y"), internal = false, countFailedValues = false) + // Fail first 3 attempts of every task. This means each task should be run 4 times. + sc.parallelize(1 to 10, 10).map { i => + acc1 += 1 + acc2 += 1 + if (TaskContext.get.attemptNumber() <= 2) { + throw new Exception("you did something wrong") + } else { + 0 + } + }.count() + // The one that counts failed values should be 4x the one that didn't, + // since we ran each task 4 times + assert(Accumulators.get(acc1.id).get.value === 40L) + assert(Accumulators.get(acc2.id).get.value === 10L) + } + + test("failed tasks collect only accumulators whose values count during failures") { + sc = new SparkContext("local", "test") + val param = AccumulatorParam.LongAccumulatorParam + val acc1 = new Accumulator(0L, param, Some("x"), internal = false, countFailedValues = true) + val acc2 = new Accumulator(0L, param, Some("y"), internal = false, countFailedValues = false) + val initialAccums = InternalAccumulator.create() + // Create a dummy task. We won't end up running this; we just want to collect + // accumulator updates from it. + val task = new Task[Int](0, 0, 0, Seq.empty[Accumulator[_]]) { + context = new TaskContextImpl(0, 0, 0L, 0, + new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), + SparkEnv.get.metricsSystem, + initialAccums) + context.taskMetrics.registerAccumulator(acc1) + context.taskMetrics.registerAccumulator(acc2) + override def runTask(tc: TaskContext): Int = 0 + } + // First, simulate task success. This should give us all the accumulators. + val accumUpdates1 = task.collectAccumulatorUpdates(taskFailed = false) + val accumUpdates2 = (initialAccums ++ Seq(acc1, acc2)).map(TaskMetricsSuite.makeInfo) + TaskMetricsSuite.assertUpdatesEquals(accumUpdates1, accumUpdates2) + // Now, simulate task failures. This should give us only the accums that count failed values. 
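+    // In this setup that means the internal accumulators from InternalAccumulator.create()
+    // (all of which count failed values) plus acc1, which was created with
+    // countFailedValues = true; acc2 (countFailedValues = false) should be left out.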
+ val accumUpdates3 = task.collectAccumulatorUpdates(taskFailed = true) + val accumUpdates4 = (initialAccums ++ Seq(acc1)).map(TaskMetricsSuite.makeInfo) + TaskMetricsSuite.assertUpdatesEquals(accumUpdates3, accumUpdates4) + } + } private object TaskContextSuite { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index cc2557c2f1df..b5385c11a926 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -21,10 +21,15 @@ import java.io.File import java.net.URL import java.nio.ByteBuffer +import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.control.NonFatal +import com.google.common.util.concurrent.MoreExecutors +import org.mockito.ArgumentCaptor +import org.mockito.Matchers.{any, anyLong} +import org.mockito.Mockito.{spy, times, verify} import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ @@ -33,13 +38,14 @@ import org.apache.spark.storage.TaskResultBlockId import org.apache.spark.TestUtils.JavaSourceFromString import org.apache.spark.util.{MutableURLClassLoader, RpcUtils, Utils} + /** * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter. * * Used to test the case where a BlockManager evicts the task result (or dies) before the * TaskResult is retrieved. */ -class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl) +private class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl) extends TaskResultGetter(sparkEnv, scheduler) { var removedResult = false @@ -72,6 +78,31 @@ class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedule } } + +/** + * A [[TaskResultGetter]] that stores the [[DirectTaskResult]]s it receives from executors + * _before_ modifying the results in any way. + */ +private class MyTaskResultGetter(env: SparkEnv, scheduler: TaskSchedulerImpl) + extends TaskResultGetter(env, scheduler) { + + // Use the current thread so we can access its results synchronously + protected override val getTaskResultExecutor = MoreExecutors.sameThreadExecutor() + + // DirectTaskResults that we receive from the executors + private val _taskResults = new ArrayBuffer[DirectTaskResult[_]] + + def taskResults: Seq[DirectTaskResult[_]] = _taskResults + + override def enqueueSuccessfulTask(tsm: TaskSetManager, tid: Long, data: ByteBuffer): Unit = { + // work on a copy since the super class still needs to use the buffer + val newBuffer = data.duplicate() + _taskResults += env.closureSerializer.newInstance().deserialize[DirectTaskResult[_]](newBuffer) + super.enqueueSuccessfulTask(tsm, tid, data) + } +} + + /** * Tests related to handling task results (both direct and indirect). 
*/ @@ -182,5 +213,39 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local Thread.currentThread.setContextClassLoader(originalClassLoader) } } + + test("task result size is set on the driver, not the executors") { + import InternalAccumulator._ + + // Set up custom TaskResultGetter and TaskSchedulerImpl spy + sc = new SparkContext("local", "test", conf) + val scheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl] + val spyScheduler = spy(scheduler) + val resultGetter = new MyTaskResultGetter(sc.env, spyScheduler) + val newDAGScheduler = new DAGScheduler(sc, spyScheduler) + scheduler.taskResultGetter = resultGetter + sc.dagScheduler = newDAGScheduler + sc.taskScheduler = spyScheduler + sc.taskScheduler.setDAGScheduler(newDAGScheduler) + + // Just run 1 task and capture the corresponding DirectTaskResult + sc.parallelize(1 to 1, 1).count() + val captor = ArgumentCaptor.forClass(classOf[DirectTaskResult[_]]) + verify(spyScheduler, times(1)).handleSuccessfulTask(any(), anyLong(), captor.capture()) + + // When a task finishes, the executor sends a serialized DirectTaskResult to the driver + // without setting the result size so as to avoid serializing the result again. Instead, + // the result size is set later in TaskResultGetter on the driver before passing the + // DirectTaskResult on to TaskSchedulerImpl. In this test, we capture the DirectTaskResult + // before and after the result size is set. + assert(resultGetter.taskResults.size === 1) + val resBefore = resultGetter.taskResults.head + val resAfter = captor.getValue + val resSizeBefore = resBefore.accumUpdates.find(_.name == Some(RESULT_SIZE)).flatMap(_.update) + val resSizeAfter = resAfter.accumUpdates.find(_.name == Some(RESULT_SIZE)).flatMap(_.update) + assert(resSizeBefore.exists(_ == 0L)) + assert(resSizeAfter.exists(_.toString.toLong > 0L)) + } + } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index ecc18fc6e15b..a2e74365641a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -24,7 +24,6 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark._ -import org.apache.spark.executor.TaskMetrics import org.apache.spark.util.ManualClock class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) @@ -38,9 +37,8 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) task: Task[_], reason: TaskEndReason, result: Any, - accumUpdates: Map[Long, Any], - taskInfo: TaskInfo, - taskMetrics: TaskMetrics) { + accumUpdates: Seq[AccumulableInfo], + taskInfo: TaskInfo) { taskScheduler.endedTasks(taskInfo.index) = reason } @@ -167,14 +165,17 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val taskSet = FakeTask.createTaskSet(1) val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val accumUpdates = taskSet.tasks.head.initialAccumulators.map { a => + new AccumulableInfo(a.id, a.name, Some(0L), None, a.isInternal, a.countFailedValues) + } // Offer a host with NO_PREF as the constraint, // we should get a nopref task immediately since that's what we only have - var taskOption = manager.resourceOffer("exec1", "host1", NO_PREF) + val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF) assert(taskOption.isDefined) 
// Tell it the task has finished - manager.handleSuccessfulTask(0, createTaskResult(0)) + manager.handleSuccessfulTask(0, createTaskResult(0, accumUpdates)) assert(sched.endedTasks(0) === Success) assert(sched.finishedManagers.contains(manager)) } @@ -184,10 +185,15 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = FakeTask.createTaskSet(3) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) + val accumUpdatesByTask: Array[Seq[AccumulableInfo]] = taskSet.tasks.map { task => + task.initialAccumulators.map { a => + new AccumulableInfo(a.id, a.name, Some(0L), None, a.isInternal, a.countFailedValues) + } + } // First three offers should all find tasks for (i <- 0 until 3) { - var taskOption = manager.resourceOffer("exec1", "host1", NO_PREF) + val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF) assert(taskOption.isDefined) val task = taskOption.get assert(task.executorId === "exec1") @@ -198,14 +204,14 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(manager.resourceOffer("exec1", "host1", NO_PREF) === None) // Finish the first two tasks - manager.handleSuccessfulTask(0, createTaskResult(0)) - manager.handleSuccessfulTask(1, createTaskResult(1)) + manager.handleSuccessfulTask(0, createTaskResult(0, accumUpdatesByTask(0))) + manager.handleSuccessfulTask(1, createTaskResult(1, accumUpdatesByTask(1))) assert(sched.endedTasks(0) === Success) assert(sched.endedTasks(1) === Success) assert(!sched.finishedManagers.contains(manager)) // Finish the last task - manager.handleSuccessfulTask(2, createTaskResult(2)) + manager.handleSuccessfulTask(2, createTaskResult(2, accumUpdatesByTask(2))) assert(sched.endedTasks(2) === Success) assert(sched.finishedManagers.contains(manager)) } @@ -620,7 +626,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // multiple 1k result val r = sc.makeRDD(0 until 10, 10).map(genBytes(1024)).collect() - assert(10 === r.size ) + assert(10 === r.size) // single 10M result val thrown = intercept[SparkException] {sc.makeRDD(genBytes(10 << 20)(0), 1).collect()} @@ -761,7 +767,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // Regression test for SPARK-2931 sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, - ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) + ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) val taskSet = FakeTask.createTaskSet(3, Seq(TaskLocation("host1")), Seq(TaskLocation("host2")), @@ -786,8 +792,10 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(TaskLocation("executor_host1_3") === ExecutorCacheTaskLocation("host1", "3")) } - def createTaskResult(id: Int): DirectTaskResult[Int] = { + private def createTaskResult( + id: Int, + accumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo]): DirectTaskResult[Int] = { val valueSer = SparkEnv.get.serializer.newInstance() - new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics) + new DirectTaskResult[Int](valueSer.serialize(id), accumUpdates) } } diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 86699e7f5695..b83ffa3282e4 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -31,6 +31,8 @@ import org.apache.spark.ui.scope.RDDOperationGraphListener class StagePageSuite extends SparkFunSuite with LocalSparkContext { + private val peakExecutionMemory = 10 + test("peak execution memory only displayed if unsafe is enabled") { val unsafeConf = "spark.sql.unsafe.enabled" val conf = new SparkConf(false).set(unsafeConf, "true") @@ -52,7 +54,7 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { val conf = new SparkConf(false).set(unsafeConf, "true") val html = renderStagePage(conf).toString().toLowerCase // verify min/25/50/75/max show task value not cumulative values - assert(html.contains("10.0 b" * 5)) + assert(html.contains(s"$peakExecutionMemory.0 b" * 5)) } /** @@ -79,14 +81,13 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { (1 to 2).foreach { taskId => val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false) - val peakExecutionMemory = 10 - taskInfo.accumulables += new AccumulableInfo(0, InternalAccumulator.PEAK_EXECUTION_MEMORY, - Some(peakExecutionMemory.toString), (peakExecutionMemory * taskId).toString, true) jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo)) jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo)) taskInfo.markSuccessful() + val taskMetrics = TaskMetrics.empty + taskMetrics.incPeakExecutionMemory(peakExecutionMemory) jobListener.onTaskEnd( - SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, TaskMetrics.empty)) + SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, taskMetrics)) } jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo)) page.render(request) diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 607617cbe91c..18a16a25bfac 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -240,7 +240,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with val taskFailedReasons = Seq( Resubmitted, new FetchFailed(null, 0, 0, 0, "ignored"), - ExceptionFailure("Exception", "description", null, null, None, None), + ExceptionFailure("Exception", "description", null, null, None), TaskResultLost, TaskKilled, ExecutorLostFailure("0", true, Some("Induced failure")), @@ -269,20 +269,22 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with val execId = "exe-1" def makeTaskMetrics(base: Int): TaskMetrics = { - val taskMetrics = new TaskMetrics() - taskMetrics.setExecutorRunTime(base + 4) - taskMetrics.incDiskBytesSpilled(base + 5) - taskMetrics.incMemoryBytesSpilled(base + 6) + val accums = InternalAccumulator.create() + accums.foreach(Accumulators.register) + val taskMetrics = new TaskMetrics(accums) val shuffleReadMetrics = taskMetrics.registerTempShuffleReadMetrics() + val shuffleWriteMetrics = taskMetrics.registerShuffleWriteMetrics() + val inputMetrics = taskMetrics.registerInputMetrics(DataReadMethod.Hadoop) + val outputMetrics = taskMetrics.registerOutputMetrics(DataWriteMethod.Hadoop) shuffleReadMetrics.incRemoteBytesRead(base + 1) shuffleReadMetrics.incLocalBytesRead(base + 9) shuffleReadMetrics.incRemoteBlocksFetched(base + 2) taskMetrics.mergeShuffleReadMetrics() - val shuffleWriteMetrics = taskMetrics.registerShuffleWriteMetrics() 
shuffleWriteMetrics.incBytesWritten(base + 3) - val inputMetrics = taskMetrics.registerInputMetrics(DataReadMethod.Hadoop) - inputMetrics.incBytesRead(base + 7) - val outputMetrics = taskMetrics.registerOutputMetrics(DataWriteMethod.Hadoop) + taskMetrics.setExecutorRunTime(base + 4) + taskMetrics.incDiskBytesSpilled(base + 5) + taskMetrics.incMemoryBytesSpilled(base + 6) + inputMetrics.setBytesRead(base + 7) outputMetrics.setBytesWritten(base + 8) taskMetrics } @@ -300,9 +302,9 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with listener.onTaskStart(SparkListenerTaskStart(1, 0, makeTaskInfo(1237L))) listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate(execId, Array( - (1234L, 0, 0, makeTaskMetrics(0)), - (1235L, 0, 0, makeTaskMetrics(100)), - (1236L, 1, 0, makeTaskMetrics(200))))) + (1234L, 0, 0, makeTaskMetrics(0).accumulatorUpdates()), + (1235L, 0, 0, makeTaskMetrics(100).accumulatorUpdates()), + (1236L, 1, 0, makeTaskMetrics(200).accumulatorUpdates())))) var stage0Data = listener.stageIdToData.get((0, 0)).get var stage1Data = listener.stageIdToData.get((1, 0)).get diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index e5ca2de4ad53..57021d1d3d52 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -22,6 +22,10 @@ import java.util.Properties import scala.collection.Map import org.json4s.jackson.JsonMethods._ +import org.json4s.JsonAST.{JArray, JInt, JString, JValue} +import org.json4s.JsonDSL._ +import org.scalatest.Assertions +import org.scalatest.exceptions.TestFailedException import org.apache.spark._ import org.apache.spark.executor._ @@ -32,12 +36,7 @@ import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage._ class JsonProtocolSuite extends SparkFunSuite { - - val jobSubmissionTime = 1421191042750L - val jobCompletionTime = 1421191296660L - - val executorAddedTime = 1421458410000L - val executorRemovedTime = 1421458922000L + import JsonProtocolSuite._ test("SparkListenerEvent") { val stageSubmitted = @@ -82,9 +81,13 @@ class JsonProtocolSuite extends SparkFunSuite { val executorAdded = SparkListenerExecutorAdded(executorAddedTime, "exec1", new ExecutorInfo("Hostee.awesome.com", 11, logUrlMap)) val executorRemoved = SparkListenerExecutorRemoved(executorRemovedTime, "exec2", "test reason") - val executorMetricsUpdate = SparkListenerExecutorMetricsUpdate("exec3", Seq( - (1L, 2, 3, makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, - hasHadoopInput = true, hasOutput = true)))) + val executorMetricsUpdate = { + // Use custom accum ID for determinism + val accumUpdates = + makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, hasHadoopInput = true, hasOutput = true) + .accumulatorUpdates().zipWithIndex.map { case (a, i) => a.copy(id = i) } + SparkListenerExecutorMetricsUpdate("exec3", Seq((1L, 2, 3, accumUpdates))) + } testEvent(stageSubmitted, stageSubmittedJsonString) testEvent(stageCompleted, stageCompletedJsonString) @@ -142,7 +145,7 @@ class JsonProtocolSuite extends SparkFunSuite { "Some exception") val fetchMetadataFailed = new MetadataFetchFailedException(17, 19, "metadata Fetch failed exception").toTaskEndReason - val exceptionFailure = new ExceptionFailure(exception, None) + val exceptionFailure = new ExceptionFailure(exception, Seq.empty[AccumulableInfo]) testTaskEndReason(Success) 
testTaskEndReason(Resubmitted) testTaskEndReason(fetchFailed) @@ -166,9 +169,8 @@ class JsonProtocolSuite extends SparkFunSuite { | Backward compatibility tests | * ============================== */ - test("ExceptionFailure backward compatibility") { - val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, - None, None) + test("ExceptionFailure backward compatibility: full stack trace") { + val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, None) val oldEvent = JsonProtocol.taskEndReasonToJson(exceptionFailure) .removeField({ _._1 == "Full Stack Trace" }) assertEquals(exceptionFailure, JsonProtocol.taskEndReasonFromJson(oldEvent)) @@ -273,14 +275,13 @@ class JsonProtocolSuite extends SparkFunSuite { assert(expectedFetchFailed === JsonProtocol.taskEndReasonFromJson(oldEvent)) } - test("ShuffleReadMetrics: Local bytes read and time taken backwards compatibility") { - // Metrics about local shuffle bytes read and local read time were added in 1.3.1. + test("ShuffleReadMetrics: Local bytes read backwards compatibility") { + // Metrics about local shuffle bytes read were added in 1.3.1. val metrics = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, hasHadoopInput = false, hasOutput = false, hasRecords = false) assert(metrics.shuffleReadMetrics.nonEmpty) val newJson = JsonProtocol.taskMetricsToJson(metrics) val oldJson = newJson.removeField { case (field, _) => field == "Local Bytes Read" } - .removeField { case (field, _) => field == "Local Read Time" } val newMetrics = JsonProtocol.taskMetricsFromJson(oldJson) assert(newMetrics.shuffleReadMetrics.get.localBytesRead == 0) } @@ -371,22 +372,76 @@ class JsonProtocolSuite extends SparkFunSuite { } test("AccumulableInfo backward compatibility") { - // "Internal" property of AccumulableInfo were added after 1.5.1. - val accumulableInfo = makeAccumulableInfo(1) + // "Internal" property of AccumulableInfo was added in 1.5.1 + val accumulableInfo = makeAccumulableInfo(1, internal = true, countFailedValues = true) val oldJson = JsonProtocol.accumulableInfoToJson(accumulableInfo) .removeField({ _._1 == "Internal" }) val oldInfo = JsonProtocol.accumulableInfoFromJson(oldJson) - assert(false === oldInfo.internal) + assert(!oldInfo.internal) + // "Count Failed Values" property of AccumulableInfo was added in 2.0.0 + val oldJson2 = JsonProtocol.accumulableInfoToJson(accumulableInfo) + .removeField({ _._1 == "Count Failed Values" }) + val oldInfo2 = JsonProtocol.accumulableInfoFromJson(oldJson2) + assert(!oldInfo2.countFailedValues) + } + + test("ExceptionFailure backward compatibility: accumulator updates") { + // "Task Metrics" was replaced with "Accumulator Updates" in 2.0.0. For older event logs, + // we should still be able to fallback to constructing the accumulator updates from the + // "Task Metrics" field, if it exists. 
+ val tm = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, hasHadoopInput = true, hasOutput = true) + val tmJson = JsonProtocol.taskMetricsToJson(tm) + val accumUpdates = tm.accumulatorUpdates() + val exception = new SparkException("sentimental") + val exceptionFailure = new ExceptionFailure(exception, accumUpdates) + val exceptionFailureJson = JsonProtocol.taskEndReasonToJson(exceptionFailure) + val tmFieldJson: JValue = "Task Metrics" -> tmJson + val oldExceptionFailureJson: JValue = + exceptionFailureJson.removeField { _._1 == "Accumulator Updates" }.merge(tmFieldJson) + val oldExceptionFailure = + JsonProtocol.taskEndReasonFromJson(oldExceptionFailureJson).asInstanceOf[ExceptionFailure] + assert(exceptionFailure.className === oldExceptionFailure.className) + assert(exceptionFailure.description === oldExceptionFailure.description) + assertSeqEquals[StackTraceElement]( + exceptionFailure.stackTrace, oldExceptionFailure.stackTrace, assertStackTraceElementEquals) + assert(exceptionFailure.fullStackTrace === oldExceptionFailure.fullStackTrace) + assertSeqEquals[AccumulableInfo]( + exceptionFailure.accumUpdates, oldExceptionFailure.accumUpdates, (x, y) => x == y) } - /** -------------------------- * - | Helper test running methods | - * --------------------------- */ + test("AccumulableInfo value de/serialization") { + import InternalAccumulator._ + val blocks = Seq[(BlockId, BlockStatus)]( + (TestBlockId("meebo"), BlockStatus(StorageLevel.MEMORY_ONLY, 1L, 2L)), + (TestBlockId("feebo"), BlockStatus(StorageLevel.DISK_ONLY, 3L, 4L))) + val blocksJson = JArray(blocks.toList.map { case (id, status) => + ("Block ID" -> id.toString) ~ + ("Status" -> JsonProtocol.blockStatusToJson(status)) + }) + testAccumValue(Some(RESULT_SIZE), 3L, JInt(3)) + testAccumValue(Some(shuffleRead.REMOTE_BLOCKS_FETCHED), 2, JInt(2)) + testAccumValue(Some(input.READ_METHOD), "aka", JString("aka")) + testAccumValue(Some(UPDATED_BLOCK_STATUSES), blocks, blocksJson) + // For anything else, we just cast the value to a string + testAccumValue(Some("anything"), blocks, JString(blocks.toString)) + testAccumValue(Some("anything"), 123, JString("123")) + } + +} + + +private[spark] object JsonProtocolSuite extends Assertions { + import InternalAccumulator._ + + private val jobSubmissionTime = 1421191042750L + private val jobCompletionTime = 1421191296660L + private val executorAddedTime = 1421458410000L + private val executorRemovedTime = 1421458922000L private def testEvent(event: SparkListenerEvent, jsonString: String) { val actualJsonString = compact(render(JsonProtocol.sparkEventToJson(event))) val newEvent = JsonProtocol.sparkEventFromJson(parse(actualJsonString)) - assertJsonStringEquals(jsonString, actualJsonString) + assertJsonStringEquals(jsonString, actualJsonString, event.getClass.getSimpleName) assertEquals(event, newEvent) } @@ -440,11 +495,19 @@ class JsonProtocolSuite extends SparkFunSuite { assertEquals(info, newInfo) } + private def testAccumValue(name: Option[String], value: Any, expectedJson: JValue): Unit = { + val json = JsonProtocol.accumValueToJson(name, value) + assert(json === expectedJson) + val newValue = JsonProtocol.accumValueFromJson(name, json) + val expectedValue = if (name.exists(_.startsWith(METRICS_PREFIX))) value else value.toString + assert(newValue === expectedValue) + } + /** -------------------------------- * | Util methods for comparing events | - * --------------------------------- */ + * --------------------------------- */ - private def assertEquals(event1: SparkListenerEvent, event2: 
SparkListenerEvent) { + private[spark] def assertEquals(event1: SparkListenerEvent, event2: SparkListenerEvent) { (event1, event2) match { case (e1: SparkListenerStageSubmitted, e2: SparkListenerStageSubmitted) => assert(e1.properties === e2.properties) @@ -478,14 +541,17 @@ class JsonProtocolSuite extends SparkFunSuite { assert(e1.executorId === e1.executorId) case (e1: SparkListenerExecutorMetricsUpdate, e2: SparkListenerExecutorMetricsUpdate) => assert(e1.execId === e2.execId) - assertSeqEquals[(Long, Int, Int, TaskMetrics)](e1.taskMetrics, e2.taskMetrics, (a, b) => { - val (taskId1, stageId1, stageAttemptId1, metrics1) = a - val (taskId2, stageId2, stageAttemptId2, metrics2) = b - assert(taskId1 === taskId2) - assert(stageId1 === stageId2) - assert(stageAttemptId1 === stageAttemptId2) - assertEquals(metrics1, metrics2) - }) + assertSeqEquals[(Long, Int, Int, Seq[AccumulableInfo])]( + e1.accumUpdates, + e2.accumUpdates, + (a, b) => { + val (taskId1, stageId1, stageAttemptId1, updates1) = a + val (taskId2, stageId2, stageAttemptId2, updates2) = b + assert(taskId1 === taskId2) + assert(stageId1 === stageId2) + assert(stageAttemptId1 === stageAttemptId2) + assertSeqEquals[AccumulableInfo](updates1, updates2, (a, b) => a.equals(b)) + }) case (e1, e2) => assert(e1 === e2) case _ => fail("Events don't match in types!") @@ -544,7 +610,6 @@ class JsonProtocolSuite extends SparkFunSuite { } private def assertEquals(metrics1: TaskMetrics, metrics2: TaskMetrics) { - assert(metrics1.hostname === metrics2.hostname) assert(metrics1.executorDeserializeTime === metrics2.executorDeserializeTime) assert(metrics1.resultSize === metrics2.resultSize) assert(metrics1.jvmGCTime === metrics2.jvmGCTime) @@ -601,7 +666,7 @@ class JsonProtocolSuite extends SparkFunSuite { assert(r1.description === r2.description) assertSeqEquals(r1.stackTrace, r2.stackTrace, assertStackTraceElementEquals) assert(r1.fullStackTrace === r2.fullStackTrace) - assertOptionEquals(r1.metrics, r2.metrics, assertTaskMetricsEquals) + assertSeqEquals[AccumulableInfo](r1.accumUpdates, r2.accumUpdates, (a, b) => a.equals(b)) case (TaskResultLost, TaskResultLost) => case (TaskKilled, TaskKilled) => case (TaskCommitDenied(jobId1, partitionId1, attemptNumber1), @@ -637,10 +702,16 @@ class JsonProtocolSuite extends SparkFunSuite { assertStackTraceElementEquals) } - private def assertJsonStringEquals(json1: String, json2: String) { + private def assertJsonStringEquals(expected: String, actual: String, metadata: String) { val formatJsonString = (json: String) => json.replaceAll("[\\s|]", "") - assert(formatJsonString(json1) === formatJsonString(json2), - s"input ${formatJsonString(json1)} got ${formatJsonString(json2)}") + if (formatJsonString(expected) != formatJsonString(actual)) { + // scalastyle:off + // This prints something useful if the JSON strings don't match + println("=== EXPECTED ===\n" + pretty(parse(expected)) + "\n") + println("=== ACTUAL ===\n" + pretty(parse(actual)) + "\n") + // scalastyle:on + throw new TestFailedException(s"$metadata JSON did not equal", 1) + } } private def assertSeqEquals[T](seq1: Seq[T], seq2: Seq[T], assertEquals: (T, T) => Unit) { @@ -699,7 +770,7 @@ class JsonProtocolSuite extends SparkFunSuite { /** ----------------------------------- * | Util methods for constructing events | - * ------------------------------------ */ + * ------------------------------------ */ private val properties = { val p = new Properties @@ -746,8 +817,12 @@ class JsonProtocolSuite extends SparkFunSuite { taskInfo } - private def 
makeAccumulableInfo(id: Int, internal: Boolean = false): AccumulableInfo = - AccumulableInfo(id, " Accumulable " + id, Some("delta" + id), "val" + id, internal) + private def makeAccumulableInfo( + id: Int, + internal: Boolean = false, + countFailedValues: Boolean = false): AccumulableInfo = + new AccumulableInfo(id, Some(s"Accumulable$id"), Some(s"delta$id"), Some(s"val$id"), + internal, countFailedValues) /** * Creates a TaskMetrics object describing a task that read data from Hadoop (if hasHadoopInput is @@ -764,7 +839,6 @@ class JsonProtocolSuite extends SparkFunSuite { hasOutput: Boolean, hasRecords: Boolean = true) = { val t = new TaskMetrics - t.setHostname("localhost") t.setExecutorDeserializeTime(a) t.setExecutorRunTime(b) t.setResultSize(c) @@ -774,7 +848,7 @@ class JsonProtocolSuite extends SparkFunSuite { if (hasHadoopInput) { val inputMetrics = t.registerInputMetrics(DataReadMethod.Hadoop) - inputMetrics.incBytesRead(d + e + f) + inputMetrics.setBytesRead(d + e + f) inputMetrics.incRecordsRead(if (hasRecords) (d + e + f) / 100 else -1) } else { val sr = t.registerTempShuffleReadMetrics() @@ -794,7 +868,7 @@ class JsonProtocolSuite extends SparkFunSuite { val sw = t.registerShuffleWriteMetrics() sw.incBytesWritten(a + b + c) sw.incWriteTime(b + c + d) - sw.setRecordsWritten(if (hasRecords) (a + b + c) / 100 else -1) + sw.incRecordsWritten(if (hasRecords) (a + b + c) / 100 else -1) } // Make at most 6 blocks t.setUpdatedBlockStatuses((1 to (e % 5 + 1)).map { i => @@ -826,14 +900,16 @@ class JsonProtocolSuite extends SparkFunSuite { | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 1, | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | } | ] | }, @@ -881,14 +957,16 @@ class JsonProtocolSuite extends SparkFunSuite { | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 1, | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | } | ] | } @@ -919,21 +997,24 @@ class JsonProtocolSuite extends SparkFunSuite { | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 3, | "Name": "Accumulable3", | "Update": "delta3", | "Value": "val3", - | "Internal": true + | "Internal": true, + | "Count Failed Values": false | } | ] | } @@ -962,21 +1043,24 @@ class JsonProtocolSuite extends SparkFunSuite { | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 3, | "Name": "Accumulable3", | "Update": "delta3", | "Value": "val3", - | "Internal": true + | "Internal": true, + | "Count Failed Values": false | } | ] | } @@ -1011,26 +1095,28 @@ class JsonProtocolSuite extends SparkFunSuite { | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": 
false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 3, | "Name": "Accumulable3", | "Update": "delta3", | "Value": "val3", - | "Internal": true + | "Internal": true, + | "Count Failed Values": false | } | ] | }, | "Task Metrics": { - | "Host Name": "localhost", | "Executor Deserialize Time": 300, | "Executor Run Time": 400, | "Result Size": 500, @@ -1044,7 +1130,7 @@ class JsonProtocolSuite extends SparkFunSuite { | "Fetch Wait Time": 900, | "Remote Bytes Read": 1000, | "Local Bytes Read": 1100, - | "Total Records Read" : 10 + | "Total Records Read": 10 | }, | "Shuffle Write Metrics": { | "Shuffle Bytes Written": 1200, @@ -1098,26 +1184,28 @@ class JsonProtocolSuite extends SparkFunSuite { | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 3, | "Name": "Accumulable3", | "Update": "delta3", | "Value": "val3", - | "Internal": true + | "Internal": true, + | "Count Failed Values": false | } | ] | }, | "Task Metrics": { - | "Host Name": "localhost", | "Executor Deserialize Time": 300, | "Executor Run Time": 400, | "Result Size": 500, @@ -1182,26 +1270,28 @@ class JsonProtocolSuite extends SparkFunSuite { | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 3, | "Name": "Accumulable3", | "Update": "delta3", | "Value": "val3", - | "Internal": true + | "Internal": true, + | "Count Failed Values": false | } | ] | }, | "Task Metrics": { - | "Host Name": "localhost", | "Executor Deserialize Time": 300, | "Executor Run Time": 400, | "Result Size": 500, @@ -1273,17 +1363,19 @@ class JsonProtocolSuite extends SparkFunSuite { | "Accumulables": [ | { | "ID": 2, - | "Name": " Accumulable 2", + | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 1, - | "Name": " Accumulable 1", + | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | } | ] | }, @@ -1331,17 +1423,19 @@ class JsonProtocolSuite extends SparkFunSuite { | "Accumulables": [ | { | "ID": 2, - | "Name": " Accumulable 2", + | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 1, - | "Name": " Accumulable 1", + | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | } | ] | }, @@ -1405,17 +1499,19 @@ class JsonProtocolSuite extends SparkFunSuite { | "Accumulables": [ | { | "ID": 2, - | "Name": " Accumulable 2", + | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 1, - | "Name": " Accumulable 1", + | "Name": 
"Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | } | ] | }, @@ -1495,17 +1591,19 @@ class JsonProtocolSuite extends SparkFunSuite { | "Accumulables": [ | { | "ID": 2, - | "Name": " Accumulable 2", + | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 1, - | "Name": " Accumulable 1", + | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | } | ] | } @@ -1657,51 +1755,208 @@ class JsonProtocolSuite extends SparkFunSuite { """ private val executorMetricsUpdateJsonString = - s""" - |{ - | "Event": "SparkListenerExecutorMetricsUpdate", - | "Executor ID": "exec3", - | "Metrics Updated": [ - | { - | "Task ID": 1, - | "Stage ID": 2, - | "Stage Attempt ID": 3, - | "Task Metrics": { - | "Host Name": "localhost", - | "Executor Deserialize Time": 300, - | "Executor Run Time": 400, - | "Result Size": 500, - | "JVM GC Time": 600, - | "Result Serialization Time": 700, - | "Memory Bytes Spilled": 800, - | "Disk Bytes Spilled": 0, - | "Input Metrics": { - | "Data Read Method": "Hadoop", - | "Bytes Read": 2100, - | "Records Read": 21 - | }, - | "Output Metrics": { - | "Data Write Method": "Hadoop", - | "Bytes Written": 1200, - | "Records Written": 12 - | }, - | "Updated Blocks": [ - | { - | "Block ID": "rdd_0_0", - | "Status": { - | "Storage Level": { - | "Use Disk": true, - | "Use Memory": true, - | "Deserialized": false, - | "Replication": 2 - | }, - | "Memory Size": 0, - | "Disk Size": 0 - | } - | } - | ] - | } - | }] - |} - """.stripMargin + s""" + |{ + | "Event": "SparkListenerExecutorMetricsUpdate", + | "Executor ID": "exec3", + | "Metrics Updated": [ + | { + | "Task ID": 1, + | "Stage ID": 2, + | "Stage Attempt ID": 3, + | "Accumulator Updates": [ + | { + | "ID": 0, + | "Name": "$EXECUTOR_DESERIALIZE_TIME", + | "Update": 300, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 1, + | "Name": "$EXECUTOR_RUN_TIME", + | "Update": 400, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 2, + | "Name": "$RESULT_SIZE", + | "Update": 500, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 3, + | "Name": "$JVM_GC_TIME", + | "Update": 600, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 4, + | "Name": "$RESULT_SERIALIZATION_TIME", + | "Update": 700, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 5, + | "Name": "$MEMORY_BYTES_SPILLED", + | "Update": 800, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 6, + | "Name": "$DISK_BYTES_SPILLED", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 7, + | "Name": "$PEAK_EXECUTION_MEMORY", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 8, + | "Name": "$UPDATED_BLOCK_STATUSES", + | "Update": [ + | { + | "BlockID": "rdd_0_0", + | "Status": { + | "StorageLevel": { + | "UseDisk": true, + | "UseMemory": true, + | "Deserialized": false, + | "Replication": 2 + | }, + | "MemorySize": 0, + | "DiskSize": 0 + | } + | } + | ], + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 9, + | "Name": "${shuffleRead.REMOTE_BLOCKS_FETCHED}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 
10, + | "Name": "${shuffleRead.LOCAL_BLOCKS_FETCHED}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 11, + | "Name": "${shuffleRead.REMOTE_BYTES_READ}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 12, + | "Name": "${shuffleRead.LOCAL_BYTES_READ}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 13, + | "Name": "${shuffleRead.FETCH_WAIT_TIME}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 14, + | "Name": "${shuffleRead.RECORDS_READ}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 15, + | "Name": "${shuffleWrite.BYTES_WRITTEN}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 16, + | "Name": "${shuffleWrite.RECORDS_WRITTEN}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 17, + | "Name": "${shuffleWrite.WRITE_TIME}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 18, + | "Name": "${input.READ_METHOD}", + | "Update": "Hadoop", + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 19, + | "Name": "${input.BYTES_READ}", + | "Update": 2100, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 20, + | "Name": "${input.RECORDS_READ}", + | "Update": 21, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 21, + | "Name": "${output.WRITE_METHOD}", + | "Update": "Hadoop", + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 22, + | "Name": "${output.BYTES_WRITTEN}", + | "Update": 1200, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 23, + | "Name": "${output.RECORDS_WRITTEN}", + | "Update": 12, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 24, + | "Name": "$TEST_ACCUM", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | } + | ] + | } + | ] + |} + """.stripMargin } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index fc7dc2181de8..968a2903f301 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -175,6 +175,15 @@ object MimaExcludes { ) ++ Seq( // SPARK-12510 Refactor ActorReceiver to support Java ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.streaming.receiver.ActorReceiver") + ) ++ Seq( + // SPARK-12895 Implement TaskMetrics using accumulators + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.internalMetricsToAccumulators"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.collectInternalAccumulators"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.collectAccumulators") + ) ++ Seq( + // SPARK-12896 Send only accumulator updates to driver, not TaskMetrics + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.Accumulable.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.Accumulator.this") ) ++ Seq( // SPARK-12692 Scala style: Fix the style violation (Space before "," or ":") ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkSink.org$apache$spark$streaming$flume$sink$Logging$$log_"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala index 73dc8cb98447..75cb6d1137c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala @@ -79,17 +79,17 @@ case class Sort( sorter.setTestSpillFrequency(testSpillFrequency) } + val metrics = TaskContext.get().taskMetrics() // Remember spill data size of this task before execute this operator so that we can // figure out how many bytes we spilled for this operator. - val spillSizeBefore = TaskContext.get().taskMetrics().memoryBytesSpilled + val spillSizeBefore = metrics.memoryBytesSpilled val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) dataSize += sorter.getPeakMemoryUsage - spillSize += TaskContext.get().taskMetrics().memoryBytesSpilled - spillSizeBefore + spillSize += metrics.memoryBytesSpilled - spillSizeBefore + metrics.incPeakExecutionMemory(sorter.getPeakMemoryUsage) - TaskContext.get().internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.getPeakMemoryUsage) sortedIterator } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 41799c596b6d..001e9c306ac4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -418,10 +418,10 @@ class TungstenAggregationIterator( val mapMemory = hashMap.getPeakMemoryUsedBytes val sorterMemory = Option(externalSorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L) val peakMemory = Math.max(mapMemory, sorterMemory) + val metrics = TaskContext.get().taskMetrics() dataSize += peakMemory - spillSize += TaskContext.get().taskMetrics().memoryBytesSpilled - spillSizeBefore - TaskContext.get().internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemory) + spillSize += metrics.memoryBytesSpilled - spillSizeBefore + metrics.incPeakExecutionMemory(peakMemory) } numOutputRows += 1 res diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index 8222b84d33e3..edd87c2d8ed0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -136,14 +136,17 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( // Find a function that will return the FileSystem bytes read by this thread. 
Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { - split.serializableHadoopSplit.value match { - case _: FileSplit | _: CombineFileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() - case _ => None + val getBytesReadCallback: Option[() => Long] = split.serializableHadoopSplit.value match { + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + case _ => None + } + + def updateBytesRead(): Unit = { + getBytesReadCallback.foreach { getBytesRead => + inputMetrics.setBytesRead(getBytesRead()) } } - inputMetrics.setBytesReadCallback(bytesReadCallback) val format = inputFormatClass.newInstance format match { @@ -208,6 +211,9 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( if (!finished) { inputMetrics.incRecordsRead(1) } + if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { + updateBytesRead() + } reader.getCurrentValue } @@ -228,8 +234,8 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( } finally { reader = null } - if (bytesReadCallback.isDefined) { - inputMetrics.updateBytesRead() + if (getBytesReadCallback.isDefined) { + updateBytesRead() } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { // If we can't get the bytes read from the FS stats, fall back to the split size, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index c9ea579b5e80..04640711d99d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -111,8 +111,7 @@ case class BroadcastHashJoin( val hashedRelation = broadcastRelation.value hashedRelation match { case unsafe: UnsafeHashedRelation => - TaskContext.get().internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) + TaskContext.get().taskMetrics().incPeakExecutionMemory(unsafe.getUnsafeSize) case _ => } hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index 6c7fa2eee5bf..db8edd169dcf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -119,8 +119,7 @@ case class BroadcastHashOuterJoin( hashTable match { case unsafe: UnsafeHashedRelation => - TaskContext.get().internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) + TaskContext.get().taskMetrics().incPeakExecutionMemory(unsafe.getUnsafeSize) case _ => } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index 004407b2e692..8929dc3af191 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -66,8 
+66,7 @@ case class BroadcastLeftSemiJoinHash( val hashedRelation = broadcastedRelation.value hashedRelation match { case unsafe: UnsafeHashedRelation => - TaskContext.get().internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) + TaskContext.get().taskMetrics().incPeakExecutionMemory(unsafe.getUnsafeSize) case _ => } hashSemiJoin(streamIter, numLeftRows, hashedRelation, numOutputRows) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 52735c9d7f8c..950dc7816241 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.metric -import org.apache.spark.{Accumulable, AccumulableParam, SparkContext} +import org.apache.spark.{Accumulable, AccumulableParam, Accumulators, SparkContext} import org.apache.spark.util.Utils /** @@ -28,7 +28,7 @@ import org.apache.spark.util.Utils */ private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T]( name: String, val param: SQLMetricParam[R, T]) - extends Accumulable[R, T](param.zero, param, Some(name), true) { + extends Accumulable[R, T](param.zero, param, Some(name), internal = true) { def reset(): Unit = { this.value = param.zero @@ -131,6 +131,8 @@ private[sql] object SQLMetrics { name: String, param: LongSQLMetricParam): LongSQLMetric = { val acc = new LongSQLMetric(name, param) + // This is an internal accumulator so we need to register it explicitly. + Accumulators.register(acc) sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 83c64f755f90..544606f1168b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -139,9 +139,8 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi override def onExecutorMetricsUpdate( executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { - for ((taskId, stageId, stageAttemptID, metrics) <- executorMetricsUpdate.taskMetrics) { - updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, metrics.accumulatorUpdates(), - finishTask = false) + for ((taskId, stageId, stageAttemptID, accumUpdates) <- executorMetricsUpdate.accumUpdates) { + updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, accumUpdates, finishTask = false) } } @@ -177,7 +176,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi taskId: Long, stageId: Int, stageAttemptID: Int, - accumulatorUpdates: Map[Long, Any], + accumulatorUpdates: Seq[AccumulableInfo], finishTask: Boolean): Unit = { _stageIdToStageMetrics.get(stageId) match { @@ -289,8 +288,10 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi for (stageId <- executionUIData.stages; stageMetrics <- _stageIdToStageMetrics.get(stageId).toIterable; taskMetrics <- stageMetrics.taskIdToMetricUpdates.values; - accumulatorUpdate <- taskMetrics.accumulatorUpdates.toSeq) yield { - accumulatorUpdate + accumulatorUpdate <- taskMetrics.accumulatorUpdates) yield { + assert(accumulatorUpdate.update.isDefined, s"accumulator update from " + + 
s"task did not have a partial value: ${accumulatorUpdate.name}") + (accumulatorUpdate.id, accumulatorUpdate.update.get) } }.filter { case (id, _) => executionUIData.accumulatorMetrics.contains(id) } mergeAccumulatorUpdates(accumulatorUpdates, accumulatorId => @@ -328,9 +329,10 @@ private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI) taskEnd.taskInfo.taskId, taskEnd.stageId, taskEnd.stageAttemptId, - taskEnd.taskInfo.accumulables.map { acc => - (acc.id, new LongSQLMetricValue(acc.update.getOrElse("0").toLong)) - }.toMap, + taskEnd.taskInfo.accumulables.map { a => + val newValue = new LongSQLMetricValue(a.update.map(_.asInstanceOf[Long]).getOrElse(0L)) + a.copy(update = Some(newValue)) + }, finishTask = true) } @@ -406,4 +408,4 @@ private[ui] class SQLStageMetrics( private[ui] class SQLTaskMetrics( val attemptId: Long, // TODO not used yet var finished: Boolean, - var accumulatorUpdates: Map[Long, Any]) + var accumulatorUpdates: Seq[AccumulableInfo]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 47308966e92c..10ccd4b8f60d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1648,7 +1648,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("external sorting updates peak execution memory") { AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") { - sortTest() + sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC").collect() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala index 9575d26fd123..273937fa8ce9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala @@ -49,8 +49,7 @@ case class ReferenceSort( val context = TaskContext.get() context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) - context.internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes) + context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop()) }, preservesPartitioning = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 9c258cb31f46..c7df8b51e2f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -71,8 +71,7 @@ class UnsafeFixedWidthAggregationMapSuite taskAttemptId = Random.nextInt(10000), attemptNumber = 0, taskMemoryManager = taskMemoryManager, - metricsSystem = null, - internalAccumulators = Seq.empty)) + metricsSystem = null)) try { f diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 8a95359d9de2..e03bd6a3e7d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -117,8 +117,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { taskAttemptId = 98456, attemptNumber = 0, taskMemoryManager = taskMemMgr, - metricsSystem = null, - internalAccumulators = Seq.empty)) + metricsSystem = null)) val sorter = new UnsafeKVExternalSorter( keySchema, valueSchema, SparkEnv.get.blockManager, pageSize) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala index 647a7e9a4e19..86c2c25c2c7e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala @@ -17,12 +17,19 @@ package org.apache.spark.sql.execution.columnar +import org.scalatest.BeforeAndAfterEach + import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ -class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext { + +class PartitionBatchPruningSuite + extends SparkFunSuite + with BeforeAndAfterEach + with SharedSQLContext { + import testImplicits._ private lazy val originalColumnBatchSize = sqlContext.conf.columnBatchSize @@ -32,30 +39,41 @@ class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext { super.beforeAll() // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, 10) - - val pruningData = sparkContext.makeRDD((1 to 100).map { key => - val string = if (((key - 1) / 10) % 2 == 0) null else key.toString - TestData(key, string) - }, 5).toDF() - pruningData.registerTempTable("pruningData") - // Enable in-memory partition pruning sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) // Enable in-memory table scan accumulators sqlContext.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") - sqlContext.cacheTable("pruningData") } override protected def afterAll(): Unit = { try { sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) - sqlContext.uncacheTable("pruningData") } finally { super.afterAll() } } + override protected def beforeEach(): Unit = { + super.beforeEach() + // This creates accumulators, which get cleaned up after every single test, + // so we need to do this before every test. 
+ val pruningData = sparkContext.makeRDD((1 to 100).map { key => + val string = if (((key - 1) / 10) % 2 == 0) null else key.toString + TestData(key, string) + }, 5).toDF() + pruningData.registerTempTable("pruningData") + sqlContext.cacheTable("pruningData") + } + + override protected def afterEach(): Unit = { + try { + sqlContext.uncacheTable("pruningData") + } finally { + super.afterEach() + } + } + // Comparisons checkBatchPruning("SELECT key FROM pruningData WHERE key = 1", 1, 1)(Seq(1)) checkBatchPruning("SELECT key FROM pruningData WHERE 1 = key", 1, 1)(Seq(1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 81a159d542c6..2c408c887847 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.ui import java.util.Properties +import org.mockito.Mockito.{mock, when} + import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ @@ -67,9 +69,11 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { ) private def createTaskMetrics(accumulatorUpdates: Map[Long, Long]): TaskMetrics = { - val metrics = new TaskMetrics - metrics.setAccumulatorsUpdater(() => accumulatorUpdates.mapValues(new LongSQLMetricValue(_))) - metrics.updateAccumulators() + val metrics = mock(classOf[TaskMetrics]) + when(metrics.accumulatorUpdates()).thenReturn(accumulatorUpdates.map { case (id, update) => + new AccumulableInfo(id, Some(""), Some(new LongSQLMetricValue(update)), + value = None, internal = true, countFailedValues = true) + }.toSeq) metrics } @@ -114,17 +118,17 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { assert(listener.getExecutionMetrics(0).isEmpty) listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( - // (task id, stage id, stage attempt, metrics) - (0L, 0, 0, createTaskMetrics(accumulatorUpdates)), - (1L, 0, 0, createTaskMetrics(accumulatorUpdates)) + // (task id, stage id, stage attempt, accum updates) + (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()), + (1L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()) ))) checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( - // (task id, stage id, stage attempt, metrics) - (0L, 0, 0, createTaskMetrics(accumulatorUpdates)), - (1L, 0, 0, createTaskMetrics(accumulatorUpdates.mapValues(_ * 2))) + // (task id, stage id, stage attempt, accum updates) + (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()), + (1L, 0, 0, createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)).accumulatorUpdates()) ))) checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 3)) @@ -133,9 +137,9 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 1))) listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( - // (task id, stage id, stage attempt, metrics) - (0L, 0, 1, createTaskMetrics(accumulatorUpdates)), - (1L, 0, 1, createTaskMetrics(accumulatorUpdates)) + // (task id, stage id, stage attempt, accum updates) + (0L, 
0, 1, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()), + (1L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()) ))) checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) @@ -173,9 +177,9 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(1, 0))) listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( - // (task id, stage id, stage attempt, metrics) - (0L, 1, 0, createTaskMetrics(accumulatorUpdates)), - (1L, 1, 0, createTaskMetrics(accumulatorUpdates)) + // (task id, stage id, stage attempt, accum updates) + (0L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()), + (1L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()) ))) checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 7)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index b46b0d2f6040..9a24a2487a25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -140,7 +140,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { .filter(_._2.name == InternalAccumulator.PEAK_EXECUTION_MEMORY) assert(peakMemoryAccumulator.size == 1) - peakMemoryAccumulator.head._2.value.toLong + peakMemoryAccumulator.head._2.value.get.asInstanceOf[Long] } assert(sparkListener.getCompletedStageInfos.length == 2) From 32f741115bda5d7d7dbfcd9fe827ecbea7303ffa Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 27 Jan 2016 13:27:32 -0800 Subject: [PATCH 042/131] [SPARK-13021][CORE] Fail fast when custom RDDs violate RDD.partition's API contract Spark's `Partition` and `RDD.partitions` APIs have a contract which requires custom implementations of `RDD.partitions` to ensure that for all `x`, `rdd.partitions(x).index == x`; in other words, the `index` reported by a repartition needs to match its position in the partitions array. If a custom RDD implementation violates this contract, then Spark has the potential to become stuck in an infinite recomputation loop when recomputing a subset of an RDD's partitions, since the tasks that are actually run will not correspond to the missing output partitions that triggered the recomputation. Here's a link to a notebook which demonstrates this problem: https://rawgit.com/JoshRosen/e520fb9a64c1c97ec985/raw/5e8a5aa8d2a18910a1607f0aa4190104adda3424/Violating%2520RDD.partitions%2520contract.html In order to guard against this infinite loop behavior, this patch modifies Spark so that it fails fast and refuses to compute RDDs' whose `partitions` violate the API contract. Author: Josh Rosen Closes #10932 from JoshRosen/SPARK-13021. --- .../main/scala/org/apache/spark/rdd/RDD.scala | 7 +++++++ .../scala/org/apache/spark/rdd/RDDSuite.scala | 18 ++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 9dad7944144d..be47172581b7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -112,6 +112,9 @@ abstract class RDD[T: ClassTag]( /** * Implemented by subclasses to return the set of partitions in this RDD. 
This method will only * be called once, so it is safe to implement a time-consuming computation in it. + * + * The partitions in this array must satisfy the following property: + * `rdd.partitions.zipWithIndex.forall { case (partition, index) => partition.index == index }` */ protected def getPartitions: Array[Partition] @@ -237,6 +240,10 @@ abstract class RDD[T: ClassTag]( checkpointRDD.map(_.partitions).getOrElse { if (partitions_ == null) { partitions_ = getPartitions + partitions_.zipWithIndex.foreach { case (partition, index) => + require(partition.index == index, + s"partitions($index).partition == ${partition.index}, but it should equal $index") + } } partitions_ } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index ef2ed445005d..80347b800a7b 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -914,6 +914,24 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { } } + test("RDD.partitions() fails fast when partitions indicies are incorrect (SPARK-13021)") { + class BadRDD[T: ClassTag](prev: RDD[T]) extends RDD[T](prev) { + + override def compute(part: Partition, context: TaskContext): Iterator[T] = { + prev.compute(part, context) + } + + override protected def getPartitions: Array[Partition] = { + prev.partitions.reverse // breaks contract, which is that `rdd.partitions(i).index == i` + } + } + val rdd = new BadRDD(sc.parallelize(1 to 100, 100)) + val e = intercept[IllegalArgumentException] { + rdd.partitions + } + assert(e.getMessage.contains("partitions")) + } + test("nested RDDs are not supported (SPARK-5063)") { val rdd: RDD[Int] = sc.parallelize(1 to 100) val rdd2: RDD[Int] = sc.parallelize(1 to 100) From 680afabe78b77e4e63e793236453d69567d24290 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 27 Jan 2016 13:29:09 -0800 Subject: [PATCH 043/131] [SPARK-12938][SQL] DataFrame API for Bloom filter This PR integrates Bloom filter from spark-sketch into DataFrame. This version resorts to RDD.aggregate for building the filter. A more performant UDAF version can be built in future follow-up PRs. This PR also add 2 specify `put` version(`putBinary` and `putLong`) into `BloomFilter`, which makes it easier to build a Bloom filter over a `DataFrame`. Author: Wenchen Fan Closes #10937 from cloud-fan/bloom-filter. --- .../apache/spark/util/sketch/BloomFilter.java | 34 ++++- .../spark/util/sketch/BloomFilterImpl.java | 141 ++++++++++++------ .../spark/util/sketch/CountMinSketchImpl.java | 47 +----- .../org/apache/spark/util/sketch/Utils.java | 48 ++++++ .../spark/sql/DataFrameStatFunctions.scala | 76 +++++++++- .../apache/spark/sql/JavaDataFrameSuite.java | 31 ++++ .../apache/spark/sql/DataFrameStatSuite.scala | 22 +++ 7 files changed, 306 insertions(+), 93 deletions(-) create mode 100644 common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index d392fb187ad6..81772fcea0ec 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -49,9 +49,9 @@ public enum Version { * {@code BloomFilter} binary format version 1 (all values written in big-endian order): *
 * <ul>
 *   <li>Version number, always 1 (32 bit)</li>
+ *   <li>Number of hash functions (32 bit)</li>
 *   <li>Total number of words of the underlying bit array (32 bit)</li>
 *   <li>The words/longs (numWords * 64 bit)</li>
- *   <li>Number of hash functions (32 bit)</li>
 * </ul>
*/ V1(1); @@ -97,6 +97,21 @@ int getVersionNumber() { */ public abstract boolean put(Object item); + /** + * A specialized variant of {@link #put(Object)}, that can only be used to put utf-8 string. + */ + public abstract boolean putString(String str); + + /** + * A specialized variant of {@link #put(Object)}, that can only be used to put long. + */ + public abstract boolean putLong(long l); + + /** + * A specialized variant of {@link #put(Object)}, that can only be used to put byte array. + */ + public abstract boolean putBinary(byte[] bytes); + /** * Determines whether a given bloom filter is compatible with this bloom filter. For two * bloom filters to be compatible, they must have the same bit size. @@ -121,6 +136,23 @@ int getVersionNumber() { */ public abstract boolean mightContain(Object item); + /** + * A specialized variant of {@link #mightContain(Object)}, that can only be used to test utf-8 + * string. + */ + public abstract boolean mightContainString(String str); + + /** + * A specialized variant of {@link #mightContain(Object)}, that can only be used to test long. + */ + public abstract boolean mightContainLong(long l); + + /** + * A specialized variant of {@link #mightContain(Object)}, that can only be used to test byte + * array. + */ + public abstract boolean mightContainBinary(byte[] bytes); + /** * Writes out this {@link BloomFilter} to an output stream in binary format. * It is the caller's responsibility to close the stream. diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java index 1c08d07afaea..35107e0b389d 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java @@ -19,10 +19,10 @@ import java.io.*; -public class BloomFilterImpl extends BloomFilter { +public class BloomFilterImpl extends BloomFilter implements Serializable { - private final int numHashFunctions; - private final BitArray bits; + private int numHashFunctions; + private BitArray bits; BloomFilterImpl(int numHashFunctions, long numBits) { this(new BitArray(numBits), numHashFunctions); @@ -33,6 +33,8 @@ private BloomFilterImpl(BitArray bits, int numHashFunctions) { this.numHashFunctions = numHashFunctions; } + private BloomFilterImpl() {} + @Override public boolean equals(Object other) { if (other == this) { @@ -63,55 +65,75 @@ public long bitSize() { return bits.bitSize(); } - private static long hashObjectToLong(Object item) { + @Override + public boolean put(Object item) { if (item instanceof String) { - try { - byte[] bytes = ((String) item).getBytes("utf-8"); - return hashBytesToLong(bytes); - } catch (UnsupportedEncodingException e) { - throw new RuntimeException("Only support utf-8 string", e); - } + return putString((String) item); + } else if (item instanceof byte[]) { + return putBinary((byte[]) item); } else { - long longValue; - - if (item instanceof Long) { - longValue = (Long) item; - } else if (item instanceof Integer) { - longValue = ((Integer) item).longValue(); - } else if (item instanceof Short) { - longValue = ((Short) item).longValue(); - } else if (item instanceof Byte) { - longValue = ((Byte) item).longValue(); - } else { - throw new IllegalArgumentException( - "Support for " + item.getClass().getName() + " not implemented" - ); - } - - int h1 = Murmur3_x86_32.hashLong(longValue, 0); - int h2 = Murmur3_x86_32.hashLong(longValue, h1); - return (((long) 
h1) << 32) | (h2 & 0xFFFFFFFFL); + return putLong(Utils.integralToLong(item)); } } - private static long hashBytesToLong(byte[] bytes) { + @Override + public boolean putString(String str) { + return putBinary(Utils.getBytesFromUTF8String(str)); + } + + @Override + public boolean putBinary(byte[] bytes) { int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0); int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1); - return (((long) h1) << 32) | (h2 & 0xFFFFFFFFL); + + long bitSize = bits.bitSize(); + boolean bitsChanged = false; + for (int i = 1; i <= numHashFunctions; i++) { + int combinedHash = h1 + (i * h2); + // Flip all the bits if it's negative (guaranteed positive number) + if (combinedHash < 0) { + combinedHash = ~combinedHash; + } + bitsChanged |= bits.set(combinedHash % bitSize); + } + return bitsChanged; } @Override - public boolean put(Object item) { + public boolean mightContainString(String str) { + return mightContainBinary(Utils.getBytesFromUTF8String(str)); + } + + @Override + public boolean mightContainBinary(byte[] bytes) { + int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0); + int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1); + long bitSize = bits.bitSize(); + for (int i = 1; i <= numHashFunctions; i++) { + int combinedHash = h1 + (i * h2); + // Flip all the bits if it's negative (guaranteed positive number) + if (combinedHash < 0) { + combinedHash = ~combinedHash; + } + if (!bits.get(combinedHash % bitSize)) { + return false; + } + } + return true; + } - // Here we first hash the input element into 2 int hash values, h1 and h2, then produce n hash - // values by `h1 + i * h2` with 1 <= i <= numHashFunctions. - // Note that `CountMinSketch` use a different strategy for long type, it hash the input long - // element with every i to produce n hash values. - long hash64 = hashObjectToLong(item); - int h1 = (int) (hash64 >> 32); - int h2 = (int) hash64; + @Override + public boolean putLong(long l) { + // Here we first hash the input long element into 2 int hash values, h1 and h2, then produce n + // hash values by `h1 + i * h2` with 1 <= i <= numHashFunctions. + // Note that `CountMinSketch` use a different strategy, it hash the input long element with + // every i to produce n hash values. + // TODO: the strategy of `CountMinSketch` looks more advanced, should we follow it here? 
+ int h1 = Murmur3_x86_32.hashLong(l, 0); + int h2 = Murmur3_x86_32.hashLong(l, h1); + long bitSize = bits.bitSize(); boolean bitsChanged = false; for (int i = 1; i <= numHashFunctions; i++) { int combinedHash = h1 + (i * h2); @@ -125,12 +147,11 @@ public boolean put(Object item) { } @Override - public boolean mightContain(Object item) { - long bitSize = bits.bitSize(); - long hash64 = hashObjectToLong(item); - int h1 = (int) (hash64 >> 32); - int h2 = (int) hash64; + public boolean mightContainLong(long l) { + int h1 = Murmur3_x86_32.hashLong(l, 0); + int h2 = Murmur3_x86_32.hashLong(l, h1); + long bitSize = bits.bitSize(); for (int i = 1; i <= numHashFunctions; i++) { int combinedHash = h1 + (i * h2); // Flip all the bits if it's negative (guaranteed positive number) @@ -144,6 +165,17 @@ public boolean mightContain(Object item) { return true; } + @Override + public boolean mightContain(Object item) { + if (item instanceof String) { + return mightContainString((String) item); + } else if (item instanceof byte[]) { + return mightContainBinary((byte[]) item); + } else { + return mightContainLong(Utils.integralToLong(item)); + } + } + @Override public boolean isCompatible(BloomFilter other) { if (other == null) { @@ -191,11 +223,11 @@ public void writeTo(OutputStream out) throws IOException { DataOutputStream dos = new DataOutputStream(out); dos.writeInt(Version.V1.getVersionNumber()); - bits.writeTo(dos); dos.writeInt(numHashFunctions); + bits.writeTo(dos); } - public static BloomFilterImpl readFrom(InputStream in) throws IOException { + private void readFrom0(InputStream in) throws IOException { DataInputStream dis = new DataInputStream(in); int version = dis.readInt(); @@ -203,6 +235,21 @@ public static BloomFilterImpl readFrom(InputStream in) throws IOException { throw new IOException("Unexpected Bloom filter version number (" + version + ")"); } - return new BloomFilterImpl(BitArray.readFrom(dis), dis.readInt()); + this.numHashFunctions = dis.readInt(); + this.bits = BitArray.readFrom(dis); + } + + public static BloomFilterImpl readFrom(InputStream in) throws IOException { + BloomFilterImpl filter = new BloomFilterImpl(); + filter.readFrom0(in); + return filter; + } + + private void writeObject(ObjectOutputStream out) throws IOException { + writeTo(out); + } + + private void readObject(ObjectInputStream in) throws IOException { + readFrom0(in); } } diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index 8cc29e407630..e49ae22906c4 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -40,8 +40,7 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable { private double eps; private double confidence; - private CountMinSketchImpl() { - } + private CountMinSketchImpl() {} CountMinSketchImpl(int depth, int width, int seed) { this.depth = depth; @@ -143,23 +142,7 @@ public void add(Object item, long count) { if (item instanceof String) { addString((String) item, count); } else { - long longValue; - - if (item instanceof Long) { - longValue = (Long) item; - } else if (item instanceof Integer) { - longValue = ((Integer) item).longValue(); - } else if (item instanceof Short) { - longValue = ((Short) item).longValue(); - } else if (item instanceof Byte) { - longValue = ((Byte) item).longValue(); - } else { - throw new 
IllegalArgumentException( - "Support for " + item.getClass().getName() + " not implemented" - ); - } - - addLong(longValue, count); + addLong(Utils.integralToLong(item), count); } } @@ -201,13 +184,7 @@ private int hash(long item, int count) { } private static int[] getHashBuckets(String key, int hashCount, int max) { - byte[] b; - try { - b = key.getBytes("UTF-8"); - } catch (UnsupportedEncodingException e) { - throw new RuntimeException(e); - } - return getHashBuckets(b, hashCount, max); + return getHashBuckets(Utils.getBytesFromUTF8String(key), hashCount, max); } private static int[] getHashBuckets(byte[] b, int hashCount, int max) { @@ -225,23 +202,7 @@ public long estimateCount(Object item) { if (item instanceof String) { return estimateCountForStringItem((String) item); } else { - long longValue; - - if (item instanceof Long) { - longValue = (Long) item; - } else if (item instanceof Integer) { - longValue = ((Integer) item).longValue(); - } else if (item instanceof Short) { - longValue = ((Short) item).longValue(); - } else if (item instanceof Byte) { - longValue = ((Byte) item).longValue(); - } else { - throw new IllegalArgumentException( - "Support for " + item.getClass().getName() + " not implemented" - ); - } - - return estimateCountForLongItem(longValue); + return estimateCountForLongItem(Utils.integralToLong(item)); } } diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java new file mode 100644 index 000000000000..a6b33313035b --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.sketch; + +import java.io.UnsupportedEncodingException; + +public class Utils { + public static byte[] getBytesFromUTF8String(String str) { + try { + return str.getBytes("utf-8"); + } catch (UnsupportedEncodingException e) { + throw new IllegalArgumentException("Only support utf-8 string", e); + } + } + + public static long integralToLong(Object i) { + long longValue; + + if (i instanceof Long) { + longValue = (Long) i; + } else if (i instanceof Integer) { + longValue = ((Integer) i).longValue(); + } else if (i instanceof Short) { + longValue = ((Short) i).longValue(); + } else if (i instanceof Byte) { + longValue = ((Byte) i).longValue(); + } else { + throw new IllegalArgumentException("Unsupported data type " + i.getClass().getName()); + } + + return longValue; + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 465b12bb59d1..b0b6995a2214 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -22,9 +22,10 @@ import java.{lang => jl, util => ju} import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.stat._ -import org.apache.spark.sql.types._ -import org.apache.spark.util.sketch.CountMinSketch +import org.apache.spark.sql.types.{IntegralType, StringType} +import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} /** * :: Experimental :: @@ -390,4 +391,75 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { } ) } + + /** + * Builds a Bloom filter over a specified column. + * + * @param colName name of the column over which the filter is built + * @param expectedNumItems expected number of items which will be put into the filter. + * @param fpp expected false positive probability of the filter. + * @since 2.0.0 + */ + def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = { + buildBloomFilter(Column(colName), BloomFilter.create(expectedNumItems, fpp)) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param col the column over which the filter is built + * @param expectedNumItems expected number of items which will be put into the filter. + * @param fpp expected false positive probability of the filter. + * @since 2.0.0 + */ + def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = { + buildBloomFilter(col, BloomFilter.create(expectedNumItems, fpp)) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param colName name of the column over which the filter is built + * @param expectedNumItems expected number of items which will be put into the filter. + * @param numBits expected number of bits of the filter. + * @since 2.0.0 + */ + def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = { + buildBloomFilter(Column(colName), BloomFilter.create(expectedNumItems, numBits)) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param col the column over which the filter is built + * @param expectedNumItems expected number of items which will be put into the filter. + * @param numBits expected number of bits of the filter. 
+ * @since 2.0.0 + */ + def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = { + buildBloomFilter(col, BloomFilter.create(expectedNumItems, numBits)) + } + + private def buildBloomFilter(col: Column, zero: BloomFilter): BloomFilter = { + val singleCol = df.select(col) + val colType = singleCol.schema.head.dataType + + require(colType == StringType || colType.isInstanceOf[IntegralType], + s"Bloom filter only supports string type and integral types, but got $colType.") + + val seqOp: (BloomFilter, InternalRow) => BloomFilter = if (colType == StringType) { + (filter, row) => + // For string type, we can get bytes of our `UTF8String` directly, and call the `putBinary` + // instead of `putString` to avoid unnecessary conversion. + filter.putBinary(row.getUTF8String(0).getBytes) + filter + } else { + (filter, row) => + // TODO: specialize it. + filter.putLong(row.get(0, colType).asInstanceOf[Number].longValue()) + filter + } + + singleCol.queryExecution.toRdd.aggregate(zero)(seqOp, _ mergeInPlace _) + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 9cf94e72d34e..0d4c128cb36d 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -40,6 +40,7 @@ import org.apache.spark.util.sketch.CountMinSketch; import static org.apache.spark.sql.functions.*; import static org.apache.spark.sql.types.DataTypes.*; +import org.apache.spark.util.sketch.BloomFilter; public class JavaDataFrameSuite { private transient JavaSparkContext jsc; @@ -300,6 +301,7 @@ public void pivot() { Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01); } + @Test public void testGenericLoad() { DataFrame df1 = context.read().format("text").load( Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString()); @@ -347,4 +349,33 @@ public void testCountMinSketch() { Assert.assertEquals(sketch4.relativeError(), 0.001, 1e-4); Assert.assertEquals(sketch4.confidence(), 0.99, 5e-3); } + + @Test + public void testBloomFilter() { + DataFrame df = context.range(1000); + + BloomFilter filter1 = df.stat().bloomFilter("id", 1000, 0.03); + assert (filter1.expectedFpp() - 0.03 < 1e-3); + for (int i = 0; i < 1000; i++) { + assert (filter1.mightContain(i)); + } + + BloomFilter filter2 = df.stat().bloomFilter(col("id").multiply(3), 1000, 0.03); + assert (filter2.expectedFpp() - 0.03 < 1e-3); + for (int i = 0; i < 1000; i++) { + assert (filter2.mightContain(i * 3)); + } + + BloomFilter filter3 = df.stat().bloomFilter("id", 1000, 64 * 5); + assert (filter3.bitSize() == 64 * 5); + for (int i = 0; i < 1000; i++) { + assert (filter3.mightContain(i)); + } + + BloomFilter filter4 = df.stat().bloomFilter(col("id").multiply(3), 1000, 64 * 5); + assert (filter4.bitSize() == 64 * 5); + for (int i = 0; i < 1000; i++) { + assert (filter4.mightContain(i * 3)); + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 8f3ea5a2860b..f01f126f7696 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -246,4 +246,26 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { .countMinSketch('id, depth = 10, width = 20, seed = 42) } } + + // This test only 
verifies some basic requirements, more correctness tests can be found in + // `BloomFilterSuite` in project spark-sketch. + test("Bloom filter") { + val df = sqlContext.range(1000) + + val filter1 = df.stat.bloomFilter("id", 1000, 0.03) + assert(filter1.expectedFpp() - 0.03 < 1e-3) + assert(0.until(1000).forall(filter1.mightContain)) + + val filter2 = df.stat.bloomFilter($"id" * 3, 1000, 0.03) + assert(filter2.expectedFpp() - 0.03 < 1e-3) + assert(0.until(1000).forall(i => filter2.mightContain(i * 3))) + + val filter3 = df.stat.bloomFilter("id", 1000, 64 * 5) + assert(filter3.bitSize() == 64 * 5) + assert(0.until(1000).forall(filter3.mightContain)) + + val filter4 = df.stat.bloomFilter($"id" * 3, 1000, 64 * 5) + assert(filter4.bitSize() == 64 * 5) + assert(0.until(1000).forall(i => filter4.mightContain(i * 3))) + } } From ef96cd3c521c175878c38a1ed6eeeab0ed8346b5 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 27 Jan 2016 13:45:00 -0800 Subject: [PATCH 044/131] [SPARK-12865][SPARK-12866][SQL] Migrate SparkSQLParser/ExtendedHiveQlParser commands to new Parser This PR moves all the functionality provided by the SparkSQLParser/ExtendedHiveQlParser to the new Parser hierarchy (SparkQl/HiveQl). This also improves the current SET command parsing: the current implementation swallows ```set role ...``` and ```set autocommit ...``` commands, this PR respects these commands (and passes them on to Hive). This PR and https://github.com/apache/spark/pull/10723 end the use of Parser-Combinator parsers for SQL parsing. As a result we can also remove the ```AbstractSQLParser``` in Catalyst. The PR is marked WIP as long as it doesn't pass all tests. cc rxin viirya winningsix (this touches https://github.com/apache/spark/pull/10144) Author: Herman van Hovell Closes #10905 from hvanhovell/SPARK-12866. --- .../sql/catalyst/parser/ExpressionParser.g | 12 +- .../spark/sql/catalyst/parser/SparkSqlLexer.g | 5 + .../sql/catalyst/parser/SparkSqlParser.g | 62 +++++++-- .../spark/sql/catalyst/CatalystQl.scala | 22 ++++ .../spark/sql/catalyst/parser/ASTNode.scala | 14 +- .../org/apache/spark/sql/SQLContext.scala | 2 +- .../apache/spark/sql/execution/SparkQl.scala | 31 +++++ .../spark/sql/execution/SparkSQLParser.scala | 124 ------------------ .../org/apache/spark/sql/SQLQuerySuite.scala | 10 +- .../spark/sql/hive/ExtendedHiveQlParser.scala | 70 ---------- .../apache/spark/sql/hive/HiveContext.scala | 8 +- .../org/apache/spark/sql/hive/HiveQl.scala | 18 ++- ...nctions-1-4a6f611305f58bdbafb2fd89ec62d797 | 4 + ...nctions-2-97cbada21ad9efda7ce9de5891deca7c | 1 + ...nctions-4-4deaa213aff83575bbaf859f79bfdd48 | 2 + .../sql/hive/execution/SQLQuerySuite.scala | 2 +- 16 files changed, 161 insertions(+), 226 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala delete mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g index 957bb234e490..0555a6ba83cb 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g @@ -167,8 +167,8 @@ intervalLiteral ((intervalConstant KW_HOUR)=> hour=intervalConstant KW_HOUR)? ((intervalConstant KW_MINUTE)=> minute=intervalConstant KW_MINUTE)? 
((intervalConstant KW_SECOND)=> second=intervalConstant KW_SECOND)? - (millisecond=intervalConstant KW_MILLISECOND)? - (microsecond=intervalConstant KW_MICROSECOND)? + ((intervalConstant KW_MILLISECOND)=> millisecond=intervalConstant KW_MILLISECOND)? + ((intervalConstant KW_MICROSECOND)=> microsecond=intervalConstant KW_MICROSECOND)? -> ^(TOK_INTERVAL ^(TOK_INTERVAL_YEAR_LITERAL $year?) ^(TOK_INTERVAL_MONTH_LITERAL $month?) @@ -505,10 +505,8 @@ identifier functionIdentifier @init { gParent.pushMsg("function identifier", state); } @after { gParent.popMsg(state); } - : db=identifier DOT fn=identifier - -> Identifier[$db.text + "." + $fn.text] - | - identifier + : + identifier (DOT identifier)? -> identifier+ ; principalIdentifier @@ -553,6 +551,8 @@ nonReserved | KW_SNAPSHOT | KW_AUTOCOMMIT | KW_ANTI + | KW_WEEK | KW_MILLISECOND | KW_MICROSECOND + | KW_CLEAR | KW_LAZY | KW_CACHE | KW_UNCACHE | KW_DFS ; //The following SQL2011 reserved keywords are used as cast function name only, but not as identifiers. diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g index e4ffc634e8bf..4374cd7ef720 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g @@ -327,6 +327,11 @@ KW_AUTOCOMMIT: 'AUTOCOMMIT'; KW_WEEK: 'WEEK'|'WEEKS'; KW_MILLISECOND: 'MILLISECOND'|'MILLISECONDS'; KW_MICROSECOND: 'MICROSECOND'|'MICROSECONDS'; +KW_CLEAR: 'CLEAR'; +KW_LAZY: 'LAZY'; +KW_CACHE: 'CACHE'; +KW_UNCACHE: 'UNCACHE'; +KW_DFS: 'DFS'; // Operators // NOTE: if you add a new function/operator, add it to sysFuncNames so that describe function _FUNC_ will work. diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g index c146ca591488..35bef00351d7 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g @@ -371,6 +371,13 @@ TOK_TXN_READ_WRITE; TOK_COMMIT; TOK_ROLLBACK; TOK_SET_AUTOCOMMIT; +TOK_CACHETABLE; +TOK_UNCACHETABLE; +TOK_CLEARCACHE; +TOK_SETCONFIG; +TOK_DFS; +TOK_ADDFILE; +TOK_ADDJAR; } @@ -515,6 +522,11 @@ import java.util.HashMap; xlateMap.put("KW_WEEK", "WEEK"); xlateMap.put("KW_MILLISECOND", "MILLISECOND"); xlateMap.put("KW_MICROSECOND", "MICROSECOND"); + xlateMap.put("KW_CLEAR", "CLEAR"); + xlateMap.put("KW_LAZY", "LAZY"); + xlateMap.put("KW_CACHE", "CACHE"); + xlateMap.put("KW_UNCACHE", "UNCACHE"); + xlateMap.put("KW_DFS", "DFS"); // Operators xlateMap.put("DOT", "."); @@ -687,8 +699,12 @@ catch (RecognitionException e) { // starting rule statement - : explainStatement EOF - | execStatement EOF + : explainStatement EOF + | execStatement EOF + | KW_ADD KW_JAR -> ^(TOK_ADDJAR) + | KW_ADD KW_FILE -> ^(TOK_ADDFILE) + | KW_DFS -> ^(TOK_DFS) + | (KW_SET)=> KW_SET -> ^(TOK_SETCONFIG) ; explainStatement @@ -717,6 +733,7 @@ execStatement | deleteStatement | updateStatement | sqlTransactionStatement + | cacheStatement ; loadStatement @@ -1390,7 +1407,7 @@ showStatement @init { pushMsg("show statement", state); } @after { popMsg(state); } : KW_SHOW (KW_DATABASES|KW_SCHEMAS) (KW_LIKE showStmtIdentifier)? -> ^(TOK_SHOWDATABASES showStmtIdentifier?) - | KW_SHOW KW_TABLES ((KW_FROM|KW_IN) db_name=identifier)? 
(KW_LIKE showStmtIdentifier|showStmtIdentifier)? -> ^(TOK_SHOWTABLES (TOK_FROM $db_name)? showStmtIdentifier?) + | KW_SHOW KW_TABLES ((KW_FROM|KW_IN) db_name=identifier)? (KW_LIKE showStmtIdentifier|showStmtIdentifier)? -> ^(TOK_SHOWTABLES ^(TOK_FROM $db_name)? showStmtIdentifier?) | KW_SHOW KW_COLUMNS (KW_FROM|KW_IN) tableName ((KW_FROM|KW_IN) db_name=identifier)? -> ^(TOK_SHOWCOLUMNS tableName $db_name?) | KW_SHOW KW_FUNCTIONS (KW_LIKE showFunctionIdentifier|showFunctionIdentifier)? -> ^(TOK_SHOWFUNCTIONS KW_LIKE? showFunctionIdentifier?) @@ -2438,12 +2455,11 @@ BEGIN user defined transaction boundaries; follows SQL 2003 standard exactly exc sqlTransactionStatement @init { pushMsg("transaction statement", state); } @after { popMsg(state); } - : - startTransactionStatement - | commitStatement - | rollbackStatement - | setAutoCommitStatement - ; + : startTransactionStatement + | commitStatement + | rollbackStatement + | setAutoCommitStatement + ; startTransactionStatement : @@ -2489,3 +2505,31 @@ setAutoCommitStatement /* END user defined transaction boundaries */ + +/* +Table Caching statements. + */ +cacheStatement +@init { pushMsg("cache statement", state); } +@after { popMsg(state); } + : + cacheTableStatement + | uncacheTableStatement + | clearCacheStatement + ; + +cacheTableStatement + : + KW_CACHE (lazy=KW_LAZY)? KW_TABLE identifier (KW_AS selectStatementWithCTE)? -> ^(TOK_CACHETABLE identifier $lazy? selectStatementWithCTE?) + ; + +uncacheTableStatement + : + KW_UNCACHE KW_TABLE identifier -> ^(TOK_UNCACHETABLE identifier) + ; + +clearCacheStatement + : + KW_CLEAR KW_CACHE -> ^(TOK_CLEARCACHE) + ; + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index f531d59a75cf..536c292ab7f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -210,6 +210,28 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } protected def nodeToPlan(node: ASTNode): LogicalPlan = node match { + case Token("TOK_SHOWFUNCTIONS", args) => + // Skip LIKE. 
+ val pattern = args match { + case like :: nodes if like.text.toUpperCase == "LIKE" => nodes + case nodes => nodes + } + + // Extract Database and Function name + pattern match { + case Nil => + ShowFunctions(None, None) + case Token(name, Nil) :: Nil => + ShowFunctions(None, Some(unquoteString(name))) + case Token(db, Nil) :: Token(name, Nil) :: Nil => + ShowFunctions(Some(unquoteString(db)), Some(unquoteString(name))) + case _ => + noParseRule("SHOW FUNCTIONS", node) + } + + case Token("TOK_DESCFUNCTION", Token(functionName, Nil) :: isExtended) => + DescribeFunction(functionName, isExtended.nonEmpty) + case Token("TOK_QUERY", queryArgs @ Token("TOK_CTE" | "TOK_FROM" | "TOK_INSERT", _) :: _) => val (fromClause: Option[ASTNode], insertClauses, cteRelations) = queryArgs match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala index ec5e71042d4b..ec9812414e19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala @@ -27,10 +27,10 @@ case class ASTNode( children: List[ASTNode], stream: TokenRewriteStream) extends TreeNode[ASTNode] { /** Cache the number of children. */ - val numChildren = children.size + val numChildren: Int = children.size /** tuple used in pattern matching. */ - val pattern = Some((token.getText, children)) + val pattern: Some[(String, List[ASTNode])] = Some((token.getText, children)) /** Line in which the ASTNode starts. */ lazy val line: Int = { @@ -55,10 +55,16 @@ case class ASTNode( } /** Origin of the ASTNode. */ - override val origin = Origin(Some(line), Some(positionInLine)) + override val origin: Origin = Origin(Some(line), Some(positionInLine)) /** Source text. */ - lazy val source = stream.toString(startIndex, stopIndex) + lazy val source: String = stream.toString(startIndex, stopIndex) + + /** Get the source text that remains after this token. 
*/ + lazy val remainder: String = { + stream.fill() + stream.toString(stopIndex + 1, stream.size() - 1).trim() + } def text: String = token.getText diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index b774da33aebe..be28df3a5155 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -204,7 +204,7 @@ class SQLContext private[sql]( protected[sql] lazy val optimizer: Optimizer = new SparkOptimizer(this) @transient - protected[sql] val sqlParser: ParserInterface = new SparkSQLParser(new SparkQl(conf)) + protected[sql] val sqlParser: ParserInterface = new SparkQl(conf) @transient protected[sql] val ddlParser: DDLParser = new DDLParser(sqlParser) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala index f3e89ef4a71f..f6055306b6c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala @@ -20,6 +20,7 @@ import org.apache.spark.sql.catalyst.{CatalystQl, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.parser.{ASTNode, ParserConf, SimpleParserConf} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} +import org.apache.spark.sql.catalyst.plans.logical private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends CatalystQl(conf) { /** Check if a command should not be explained. */ @@ -27,6 +28,18 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly protected override def nodeToPlan(node: ASTNode): LogicalPlan = { node match { + case Token("TOK_SETCONFIG", Nil) => + val keyValueSeparatorIndex = node.remainder.indexOf('=') + if (keyValueSeparatorIndex >= 0) { + val key = node.remainder.substring(0, keyValueSeparatorIndex).trim + val value = node.remainder.substring(keyValueSeparatorIndex + 1).trim + SetCommand(Some(key -> Option(value))) + } else if (node.remainder.nonEmpty) { + SetCommand(Some(node.remainder -> None)) + } else { + SetCommand(None) + } + // Just fake explain for any of the native commands. 
case Token("TOK_EXPLAIN", explainArgs) if isNoExplainCommand(explainArgs.head.text) => ExplainCommand(OneRowRelation) @@ -75,6 +88,24 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly } } + case Token("TOK_CACHETABLE", Token(tableName, Nil) :: args) => + val Seq(lzy, selectAst) = getClauses(Seq("LAZY", "TOK_QUERY"), args) + CacheTableCommand(tableName, selectAst.map(nodeToPlan), lzy.isDefined) + + case Token("TOK_UNCACHETABLE", Token(tableName, Nil) :: Nil) => + UncacheTableCommand(tableName) + + case Token("TOK_CLEARCACHE", Nil) => + ClearCacheCommand + + case Token("TOK_SHOWTABLES", args) => + val databaseName = args match { + case Nil => None + case Token("TOK_FROM", Token(dbName, Nil) :: Nil) :: Nil => Option(dbName) + case _ => noParseRule("SHOW TABLES", node) + } + ShowTablesCommand(databaseName) + case _ => super.nodeToPlan(node) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala deleted file mode 100644 index d2d827156372..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import scala.util.parsing.combinator.RegexParsers - -import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, ParserInterface, TableIdentifier} -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} -import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types.StringType - -/** - * The top level Spark SQL parser. This parser recognizes syntaxes that are available for all SQL - * dialects supported by Spark SQL, and delegates all the other syntaxes to the `fallback` parser. - * - * @param fallback A function that returns the next parser in the chain. This is a call-by-name - * parameter because this allows us to return a different dialect if we - * have to. 
- */ -class SparkSQLParser(fallback: => ParserInterface) extends AbstractSparkSQLParser { - - override def parseExpression(sql: String): Expression = fallback.parseExpression(sql) - - override def parseTableIdentifier(sql: String): TableIdentifier = - fallback.parseTableIdentifier(sql) - - // A parser for the key-value part of the "SET [key = [value ]]" syntax - private object SetCommandParser extends RegexParsers { - private val key: Parser[String] = "(?m)[^=]+".r - - private val value: Parser[String] = "(?m).*$".r - - private val output: Seq[Attribute] = Seq(AttributeReference("", StringType, nullable = false)()) - - private val pair: Parser[LogicalPlan] = - (key ~ ("=".r ~> value).?).? ^^ { - case None => SetCommand(None) - case Some(k ~ v) => SetCommand(Some(k.trim -> v.map(_.trim))) - } - - def apply(input: String): LogicalPlan = parseAll(pair, input) match { - case Success(plan, _) => plan - case x => sys.error(x.toString) - } - } - - protected val AS = Keyword("AS") - protected val CACHE = Keyword("CACHE") - protected val CLEAR = Keyword("CLEAR") - protected val DESCRIBE = Keyword("DESCRIBE") - protected val EXTENDED = Keyword("EXTENDED") - protected val FUNCTION = Keyword("FUNCTION") - protected val FUNCTIONS = Keyword("FUNCTIONS") - protected val IN = Keyword("IN") - protected val LAZY = Keyword("LAZY") - protected val SET = Keyword("SET") - protected val SHOW = Keyword("SHOW") - protected val TABLE = Keyword("TABLE") - protected val TABLES = Keyword("TABLES") - protected val UNCACHE = Keyword("UNCACHE") - - override protected lazy val start: Parser[LogicalPlan] = - cache | uncache | set | show | desc | others - - private lazy val cache: Parser[LogicalPlan] = - CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ { - case isLazy ~ tableName ~ plan => - CacheTableCommand(tableName, plan.map(fallback.parsePlan), isLazy.isDefined) - } - - private lazy val uncache: Parser[LogicalPlan] = - ( UNCACHE ~ TABLE ~> ident ^^ { - case tableName => UncacheTableCommand(tableName) - } - | CLEAR ~ CACHE ^^^ ClearCacheCommand - ) - - private lazy val set: Parser[LogicalPlan] = - SET ~> restInput ^^ { - case input => SetCommandParser(input) - } - - // It can be the following patterns: - // SHOW FUNCTIONS; - // SHOW FUNCTIONS mydb.func1; - // SHOW FUNCTIONS func1; - // SHOW FUNCTIONS `mydb.a`.`func1.aa`; - private lazy val show: Parser[LogicalPlan] = - ( SHOW ~> TABLES ~ (IN ~> ident).? ^^ { - case _ ~ dbName => ShowTablesCommand(dbName) - } - | SHOW ~ FUNCTIONS ~> ((ident <~ ".").? ~ (ident | stringLit)).? ^^ { - case Some(f) => logical.ShowFunctions(f._1, Some(f._2)) - case None => logical.ShowFunctions(None, None) - } - ) - - private lazy val desc: Parser[LogicalPlan] = - DESCRIBE ~ FUNCTION ~> EXTENDED.? 
~ (ident | stringLit) ^^ { - case isExtended ~ functionName => logical.DescribeFunction(functionName, isExtended.isDefined) - } - - private lazy val others: Parser[LogicalPlan] = - wholeInput ^^ { - case input => fallback.parsePlan(input) - } - -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 10ccd4b8f60d..989cb2942918 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -56,8 +56,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("show functions") { - checkAnswer(sql("SHOW functions"), - FunctionRegistry.builtin.listFunction().sorted.map(Row(_))) + def getFunctions(pattern: String): Seq[Row] = { + val regex = java.util.regex.Pattern.compile(pattern) + sqlContext.functionRegistry.listFunction().filter(regex.matcher(_).matches()).map(Row(_)) + } + checkAnswer(sql("SHOW functions"), getFunctions(".*")) + Seq("^c.*", ".*e$", "log.*", ".*date.*").foreach { pattern => + checkAnswer(sql(s"SHOW FUNCTIONS '$pattern'"), getFunctions(pattern)) + } } test("describe functions") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala deleted file mode 100644 index 313ba18f6aef..000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import scala.language.implicitConversions - -import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, TableIdentifier} -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.hive.execution.{AddFile, AddJar, HiveNativeCommand} - -/** - * A parser that recognizes all HiveQL constructs together with Spark SQL specific extensions. 
- */ -private[hive] class ExtendedHiveQlParser(sqlContext: HiveContext) extends AbstractSparkSQLParser { - - val parser = new HiveQl(sqlContext.conf) - - override def parseExpression(sql: String): Expression = parser.parseExpression(sql) - - override def parseTableIdentifier(sql: String): TableIdentifier = - parser.parseTableIdentifier(sql) - - // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` - // properties via reflection the class in runtime for constructing the SqlLexical object - protected val ADD = Keyword("ADD") - protected val DFS = Keyword("DFS") - protected val FILE = Keyword("FILE") - protected val JAR = Keyword("JAR") - - protected lazy val start: Parser[LogicalPlan] = dfs | addJar | addFile | hiveQl - - protected lazy val hiveQl: Parser[LogicalPlan] = - restInput ^^ { - case statement => - sqlContext.executionHive.withHiveState { - parser.parsePlan(statement.trim) - } - } - - protected lazy val dfs: Parser[LogicalPlan] = - DFS ~> wholeInput ^^ { - case command => HiveNativeCommand(command.trim) - } - - private lazy val addFile: Parser[LogicalPlan] = - ADD ~ FILE ~> restInput ^^ { - case input => AddFile(input.trim) - } - - private lazy val addJar: Parser[LogicalPlan] = - ADD ~ JAR ~> restInput ^^ { - case input => AddJar(input.trim) - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index eaca3c9269bb..1797ea54f250 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -316,7 +316,9 @@ class HiveContext private[hive]( } protected[sql] override def parseSql(sql: String): LogicalPlan = { - super.parseSql(substitutor.substitute(hiveconf, sql)) + executionHive.withHiveState { + super.parseSql(substitutor.substitute(hiveconf, sql)) + } } override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = @@ -546,9 +548,7 @@ class HiveContext private[hive]( } @transient - protected[sql] override val sqlParser: ParserInterface = { - new SparkSQLParser(new ExtendedHiveQlParser(this)) - } + protected[sql] override val sqlParser: ParserInterface = new HiveQl(conf) @transient private val hivePlanner = new SparkPlanner(this) with HiveStrategies { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 46246f8191db..22841ed2116d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -35,11 +35,12 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser._ import org.apache.spark.sql.catalyst.parser.ParseUtils._ +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.SparkQl import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.hive.client._ -import org.apache.spark.sql.hive.execution.{AnalyzeTable, DropTable, HiveNativeCommand, HiveScriptIOSchema} +import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.types._ import org.apache.spark.sql.AnalysisException @@ -113,7 +114,6 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging "TOK_CREATEROLE", "TOK_DESCDATABASE", - "TOK_DESCFUNCTION", 
"TOK_DROPDATABASE", "TOK_DROPFUNCTION", @@ -151,7 +151,6 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging "TOK_SHOW_TRANSACTIONS", "TOK_SHOWCOLUMNS", "TOK_SHOWDATABASES", - "TOK_SHOWFUNCTIONS", "TOK_SHOWINDEXES", "TOK_SHOWLOCKS", "TOK_SHOWPARTITIONS", @@ -244,6 +243,15 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging protected override def nodeToPlan(node: ASTNode): LogicalPlan = { node match { + case Token("TOK_DFS", Nil) => + HiveNativeCommand(node.source + " " + node.remainder) + + case Token("TOK_ADDFILE", Nil) => + AddFile(node.remainder) + + case Token("TOK_ADDJAR", Nil) => + AddJar(node.remainder) + // Special drop table that also uncaches. case Token("TOK_DROPTABLE", Token("TOK_TABNAME", tableNameParts) :: ifExists) => val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") @@ -558,7 +566,7 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging protected override def nodeToTransformation( node: ASTNode, - child: LogicalPlan): Option[ScriptTransformation] = node match { + child: LogicalPlan): Option[logical.ScriptTransformation] = node match { case Token("TOK_SELEXPR", Token("TOK_TRANSFORM", Token("TOK_EXPLIST", inputExprs) :: @@ -651,7 +659,7 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging schemaLess) Some( - ScriptTransformation( + logical.ScriptTransformation( inputExprs.map(nodeToExpr), unescapedScript, output, diff --git a/sql/hive/src/test/resources/golden/show_functions-1-4a6f611305f58bdbafb2fd89ec62d797 b/sql/hive/src/test/resources/golden/show_functions-1-4a6f611305f58bdbafb2fd89ec62d797 index 175795534fff..f400819b67c2 100644 --- a/sql/hive/src/test/resources/golden/show_functions-1-4a6f611305f58bdbafb2fd89ec62d797 +++ b/sql/hive/src/test/resources/golden/show_functions-1-4a6f611305f58bdbafb2fd89ec62d797 @@ -1,4 +1,5 @@ case +cbrt ceil ceiling coalesce @@ -17,3 +18,6 @@ covar_samp create_union cume_dist current_database +current_date +current_timestamp +current_user diff --git a/sql/hive/src/test/resources/golden/show_functions-2-97cbada21ad9efda7ce9de5891deca7c b/sql/hive/src/test/resources/golden/show_functions-2-97cbada21ad9efda7ce9de5891deca7c index 3c25d656bda1..19458fc86e43 100644 --- a/sql/hive/src/test/resources/golden/show_functions-2-97cbada21ad9efda7ce9de5891deca7c +++ b/sql/hive/src/test/resources/golden/show_functions-2-97cbada21ad9efda7ce9de5891deca7c @@ -2,6 +2,7 @@ assert_true case coalesce current_database +current_date decode e encode diff --git a/sql/hive/src/test/resources/golden/show_functions-4-4deaa213aff83575bbaf859f79bfdd48 b/sql/hive/src/test/resources/golden/show_functions-4-4deaa213aff83575bbaf859f79bfdd48 index cd2e58d04a4e..1d05f843a7e0 100644 --- a/sql/hive/src/test/resources/golden/show_functions-4-4deaa213aff83575bbaf859f79bfdd48 +++ b/sql/hive/src/test/resources/golden/show_functions-4-4deaa213aff83575bbaf859f79bfdd48 @@ -1,4 +1,6 @@ +current_date date_add +date_format date_sub datediff to_date diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 9e53d8a81e75..0d62d799c8dc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.parser.ParserConf import org.apache.spark.sql.execution.SparkQl 
import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation -import org.apache.spark.sql.hive.{ExtendedHiveQlParser, HiveContext, HiveQl, MetastoreRelation} +import org.apache.spark.sql.hive.{HiveContext, HiveQl, MetastoreRelation} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ From d702f0c170d5c39df501e173813f8a7718e3b3c6 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 27 Jan 2016 14:01:55 -0800 Subject: [PATCH 045/131] [HOTFIX] Fix Scala 2.11 compilation by explicitly marking annotated parameters as vals (SI-8813). Caused by #10835. Author: Andrew Or Closes #10955 from andrewor14/fix-scala211. --- core/src/main/scala/org/apache/spark/Accumulable.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/Accumulable.scala b/core/src/main/scala/org/apache/spark/Accumulable.scala index bde136141f40..52f572b63fa9 100644 --- a/core/src/main/scala/org/apache/spark/Accumulable.scala +++ b/core/src/main/scala/org/apache/spark/Accumulable.scala @@ -57,7 +57,8 @@ import org.apache.spark.util.Utils */ class Accumulable[R, T] private ( val id: Long, - @transient initialValue: R, + // SI-8813: This must explicitly be a private val, or else scala 2.11 doesn't compile + @transient private val initialValue: R, param: AccumulableParam[R, T], val name: Option[String], internal: Boolean, From 4a091232122b51f10521a68de8b1d9eb853b563d Mon Sep 17 00:00:00 2001 From: Nong Li Date: Wed, 27 Jan 2016 15:35:31 -0800 Subject: [PATCH 046/131] [SPARK-13045] [SQL] Remove ColumnVector.Struct in favor of ColumnarBatch.Row These two classes became identical as the implementation progressed. Author: Nong Li Closes #10952 from nongli/spark-13045. --- .../execution/vectorized/ColumnVector.java | 104 +----------------- .../execution/vectorized/ColumnarBatch.java | 40 ++++--- .../vectorized/ColumnarBatchSuite.scala | 8 +- 3 files changed, 32 insertions(+), 120 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index c119758d68b3..a0bf8734b654 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -210,104 +210,6 @@ public Object get(int ordinal, DataType dataType) { } } - /** - * Holder object to return a struct. This object is intended to be reused. - */ - public static final class Struct extends InternalRow { - // The fields that make up this struct. 
For example, if the struct had 2 int fields, the access - // to it would be: - // int f1 = fields[0].getInt[rowId] - // int f2 = fields[1].getInt[rowId] - public final ColumnVector[] fields; - - @Override - public boolean isNullAt(int fieldIdx) { return fields[fieldIdx].getIsNull(rowId); } - - @Override - public boolean getBoolean(int ordinal) { - throw new NotImplementedException(); - } - - public byte getByte(int fieldIdx) { return fields[fieldIdx].getByte(rowId); } - - @Override - public short getShort(int ordinal) { - throw new NotImplementedException(); - } - - public int getInt(int fieldIdx) { return fields[fieldIdx].getInt(rowId); } - public long getLong(int fieldIdx) { return fields[fieldIdx].getLong(rowId); } - - @Override - public float getFloat(int ordinal) { - throw new NotImplementedException(); - } - - public double getDouble(int fieldIdx) { return fields[fieldIdx].getDouble(rowId); } - - @Override - public Decimal getDecimal(int ordinal, int precision, int scale) { - throw new NotImplementedException(); - } - - @Override - public UTF8String getUTF8String(int ordinal) { - Array a = getByteArray(ordinal); - return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); - } - - @Override - public byte[] getBinary(int ordinal) { - throw new NotImplementedException(); - } - - @Override - public CalendarInterval getInterval(int ordinal) { - throw new NotImplementedException(); - } - - @Override - public InternalRow getStruct(int ordinal, int numFields) { - return fields[ordinal].getStruct(rowId); - } - - public Array getArray(int fieldIdx) { return fields[fieldIdx].getArray(rowId); } - - @Override - public MapData getMap(int ordinal) { - throw new NotImplementedException(); - } - - @Override - public Object get(int ordinal, DataType dataType) { - throw new NotImplementedException(); - } - - public Array getByteArray(int fieldIdx) { return fields[fieldIdx].getByteArray(rowId); } - public Struct getStruct(int fieldIdx) { return fields[fieldIdx].getStruct(rowId); } - - @Override - public final int numFields() { - return fields.length; - } - - @Override - public InternalRow copy() { - throw new NotImplementedException(); - } - - @Override - public boolean anyNull() { - throw new NotImplementedException(); - } - - protected int rowId; - - protected Struct(ColumnVector[] fields) { - this.fields = fields; - } - } - /** * Returns the data type of this column. */ @@ -494,7 +396,7 @@ public void reset() { /** * Returns a utility object to get structs. */ - public Struct getStruct(int rowId) { + public ColumnarBatch.Row getStruct(int rowId) { resultStruct.rowId = rowId; return resultStruct; } @@ -749,7 +651,7 @@ public final int appendStruct(boolean isNull) { /** * Reusable Struct holder for getStruct(). 
*/ - protected final Struct resultStruct; + protected final ColumnarBatch.Row resultStruct; /** * Sets up the common state and also handles creating the child columns if this is a nested @@ -779,7 +681,7 @@ protected ColumnVector(int capacity, DataType type, MemoryMode memMode) { this.childColumns[i] = ColumnVector.allocate(capacity, st.fields()[i].dataType(), memMode); } this.resultArray = null; - this.resultStruct = new Struct(this.childColumns); + this.resultStruct = new ColumnarBatch.Row(this.childColumns); } else { this.childColumns = null; this.resultArray = null; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index d558dae50c22..5a575811fa89 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -86,13 +86,23 @@ public void close() { * performance is lost with this translation. */ public static final class Row extends InternalRow { - private int rowId; + protected int rowId; private final ColumnarBatch parent; private final int fixedLenRowSize; + private final ColumnVector[] columns; + // Ctor used if this is a top level row. private Row(ColumnarBatch parent) { this.parent = parent; this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(parent.numCols()); + this.columns = parent.columns; + } + + // Ctor used if this is a struct. + protected Row(ColumnVector[] columns) { + this.parent = null; + this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(columns.length); + this.columns = columns; } /** @@ -103,23 +113,23 @@ public final void markFiltered() { parent.markFiltered(rowId); } + public ColumnVector[] columns() { return columns; } + @Override - public final int numFields() { - return parent.numCols(); - } + public final int numFields() { return columns.length; } @Override /** * Revisit this. This is expensive. 
*/ public final InternalRow copy() { - UnsafeRow row = new UnsafeRow(parent.numCols()); + UnsafeRow row = new UnsafeRow(numFields()); row.pointTo(new byte[fixedLenRowSize], fixedLenRowSize); - for (int i = 0; i < parent.numCols(); i++) { + for (int i = 0; i < numFields(); i++) { if (isNullAt(i)) { row.setNullAt(i); } else { - DataType dt = parent.schema.fields()[i].dataType(); + DataType dt = columns[i].dataType(); if (dt instanceof IntegerType) { row.setInt(i, getInt(i)); } else if (dt instanceof LongType) { @@ -141,7 +151,7 @@ public final boolean anyNull() { @Override public final boolean isNullAt(int ordinal) { - return parent.column(ordinal).getIsNull(rowId); + return columns[ordinal].getIsNull(rowId); } @Override @@ -150,7 +160,7 @@ public final boolean getBoolean(int ordinal) { } @Override - public final byte getByte(int ordinal) { return parent.column(ordinal).getByte(rowId); } + public final byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); } @Override public final short getShort(int ordinal) { @@ -159,11 +169,11 @@ public final short getShort(int ordinal) { @Override public final int getInt(int ordinal) { - return parent.column(ordinal).getInt(rowId); + return columns[ordinal].getInt(rowId); } @Override - public final long getLong(int ordinal) { return parent.column(ordinal).getLong(rowId); } + public final long getLong(int ordinal) { return columns[ordinal].getLong(rowId); } @Override public final float getFloat(int ordinal) { @@ -172,7 +182,7 @@ public final float getFloat(int ordinal) { @Override public final double getDouble(int ordinal) { - return parent.column(ordinal).getDouble(rowId); + return columns[ordinal].getDouble(rowId); } @Override @@ -182,7 +192,7 @@ public final Decimal getDecimal(int ordinal, int precision, int scale) { @Override public final UTF8String getUTF8String(int ordinal) { - ColumnVector.Array a = parent.column(ordinal).getByteArray(rowId); + ColumnVector.Array a = columns[ordinal].getByteArray(rowId); return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); } @@ -198,12 +208,12 @@ public final CalendarInterval getInterval(int ordinal) { @Override public final InternalRow getStruct(int ordinal, int numFields) { - return parent.column(ordinal).getStruct(rowId); + return columns[ordinal].getStruct(rowId); } @Override public final ArrayData getArray(int ordinal) { - return parent.column(ordinal).getArray(rowId); + return columns[ordinal].getArray(rowId); } @Override diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 215ca9ab6b77..67cc08b6fc8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -439,10 +439,10 @@ class ColumnarBatchSuite extends SparkFunSuite { c2.putDouble(1, 5.67) val s = column.getStruct(0) - assert(s.fields(0).getInt(0) == 123) - assert(s.fields(0).getInt(1) == 456) - assert(s.fields(1).getDouble(0) == 3.45) - assert(s.fields(1).getDouble(1) == 5.67) + assert(s.columns()(0).getInt(0) == 123) + assert(s.columns()(0).getInt(1) == 456) + assert(s.columns()(1).getDouble(0) == 3.45) + assert(s.columns()(1).getDouble(1) == 5.67) assert(s.getInt(0) == 123) assert(s.getDouble(1) == 3.45) From c2204436a15838f2dce44e3cfb0fe58236ef6196 Mon Sep 17 00:00:00 2001 From: James Lohse Date: Thu, 28 Jan 2016 10:50:50 +0000 
Subject: [PATCH 047/131] Provide same info as in spark-submit --help this is stated for --packages and --repositories. Without stating it for --jars, people expect a standard java classpath to work, with expansion and using a different delimiter than a comma. Currently this is only state in the --help for spark-submit "Comma-separated list of local jars to include on the driver and executor classpaths." Author: James Lohse Closes #10890 from jimlohse/patch-1. --- docs/submitting-applications.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index acbb0f298fe4..413532f2f6cf 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -177,8 +177,9 @@ debugging information by running `spark-submit` with the `--verbose` option. # Advanced Dependency Management When using `spark-submit`, the application jar along with any jars included with the `--jars` option -will be automatically transferred to the cluster. Spark uses the following URL scheme to allow -different strategies for disseminating jars: +will be automatically transferred to the cluster. URLs supplied after `--jars` must be separated by commas. That list is included on the driver and executor classpaths. Directory expansion does not work with `--jars`. + +Spark uses the following URL scheme to allow different strategies for disseminating jars: - **file:** - Absolute paths and `file:/` URIs are served by the driver's HTTP file server, and every executor pulls the file from the driver HTTP server. From 415d0a859b7a76f3a866ec62ab472c4050f2a01b Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 28 Jan 2016 12:26:03 -0800 Subject: [PATCH 048/131] [SPARK-12818][SQL] Specialized integral and string types for Count-min Sketch This PR is a follow-up of #10911. It adds specialized update methods for `CountMinSketch` so that we can avoid doing internal/external row format conversion in `DataFrame.countMinSketch()`. Author: Cheng Lian Closes #10968 from liancheng/cms-specialized. --- .../spark/util/sketch/CountMinSketch.java | 34 +++++++++- .../spark/util/sketch/CountMinSketchImpl.java | 35 ++++++++-- .../spark/sql/DataFrameStatFunctions.scala | 65 +++++++++++-------- 3 files changed, 99 insertions(+), 35 deletions(-) diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index 5692e574d4c7..f0aac5bb00df 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -115,15 +115,45 @@ int getVersionNumber() { public abstract long totalCount(); /** - * Adds 1 to {@code item}. + * Increments {@code item}'s count by one. */ public abstract void add(Object item); /** - * Adds {@code count} to {@code item}. + * Increments {@code item}'s count by {@code count}. */ public abstract void add(Object item, long count); + /** + * Increments {@code item}'s count by one. + */ + public abstract void addLong(long item); + + /** + * Increments {@code item}'s count by {@code count}. + */ + public abstract void addLong(long item, long count); + + /** + * Increments {@code item}'s count by one. + */ + public abstract void addString(String item); + + /** + * Increments {@code item}'s count by {@code count}. + */ + public abstract void addString(String item, long count); + + /** + * Increments {@code item}'s count by one. 
+ */ + public abstract void addBinary(byte[] item); + + /** + * Increments {@code item}'s count by {@code count}. + */ + public abstract void addBinary(byte[] item, long count); + /** * Returns the estimated frequency of {@code item}. */ diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index e49ae22906c4..c0631c6778df 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -25,7 +25,6 @@ import java.io.ObjectOutputStream; import java.io.OutputStream; import java.io.Serializable; -import java.io.UnsupportedEncodingException; import java.util.Arrays; import java.util.Random; @@ -146,27 +145,49 @@ public void add(Object item, long count) { } } - private void addString(String item, long count) { + @Override + public void addString(String item) { + addString(item, 1); + } + + @Override + public void addString(String item, long count) { + addBinary(Utils.getBytesFromUTF8String(item), count); + } + + @Override + public void addLong(long item) { + addLong(item, 1); + } + + @Override + public void addLong(long item, long count) { if (count < 0) { throw new IllegalArgumentException("Negative increments not implemented"); } - int[] buckets = getHashBuckets(item, depth, width); - for (int i = 0; i < depth; ++i) { - table[i][buckets[i]] += count; + table[i][hash(item, i)] += count; } totalCount += count; } - private void addLong(long item, long count) { + @Override + public void addBinary(byte[] item) { + addBinary(item, 1); + } + + @Override + public void addBinary(byte[] item, long count) { if (count < 0) { throw new IllegalArgumentException("Negative increments not implemented"); } + int[] buckets = getHashBuckets(item, depth, width); + for (int i = 0; i < depth; ++i) { - table[i][hash(item, i)] += count; + table[i][buckets[i]] += count; } totalCount += count; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index b0b6995a2214..bb3cc02800d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.stat._ -import org.apache.spark.sql.types.{IntegralType, StringType} +import org.apache.spark.sql.types._ import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} /** @@ -109,7 +109,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * Null elements will be replaced by "null", and back ticks will be dropped from elements if they * exist. * - * * @param col1 The name of the first column. Distinct items will make the first item of * each row. * @param col2 The name of the second column. Distinct items will make the column names @@ -374,21 +373,27 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { val singleCol = df.select(col) val colType = singleCol.schema.head.dataType - require( - colType == StringType || colType.isInstanceOf[IntegralType], - s"Count-min Sketch only supports string type and integral types, " + - s"and does not support type $colType." 
- ) + val updater: (CountMinSketch, InternalRow) => Unit = colType match { + // For string type, we can get bytes of our `UTF8String` directly, and call the `addBinary` + // instead of `addString` to avoid unnecessary conversion. + case StringType => (sketch, row) => sketch.addBinary(row.getUTF8String(0).getBytes) + case ByteType => (sketch, row) => sketch.addLong(row.getByte(0)) + case ShortType => (sketch, row) => sketch.addLong(row.getShort(0)) + case IntegerType => (sketch, row) => sketch.addLong(row.getInt(0)) + case LongType => (sketch, row) => sketch.addLong(row.getLong(0)) + case _ => + throw new IllegalArgumentException( + s"Count-min Sketch only supports string type and integral types, " + + s"and does not support type $colType." + ) + } - singleCol.rdd.aggregate(zero)( - (sketch: CountMinSketch, row: Row) => { - sketch.add(row.get(0)) + singleCol.queryExecution.toRdd.aggregate(zero)( + (sketch: CountMinSketch, row: InternalRow) => { + updater(sketch, row) sketch }, - - (sketch1: CountMinSketch, sketch2: CountMinSketch) => { - sketch1.mergeInPlace(sketch2) - } + (sketch1, sketch2) => sketch1.mergeInPlace(sketch2) ) } @@ -447,19 +452,27 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { require(colType == StringType || colType.isInstanceOf[IntegralType], s"Bloom filter only supports string type and integral types, but got $colType.") - val seqOp: (BloomFilter, InternalRow) => BloomFilter = if (colType == StringType) { - (filter, row) => - // For string type, we can get bytes of our `UTF8String` directly, and call the `putBinary` - // instead of `putString` to avoid unnecessary conversion. - filter.putBinary(row.getUTF8String(0).getBytes) - filter - } else { - (filter, row) => - // TODO: specialize it. - filter.putLong(row.get(0, colType).asInstanceOf[Number].longValue()) - filter + val updater: (BloomFilter, InternalRow) => Unit = colType match { + // For string type, we can get bytes of our `UTF8String` directly, and call the `putBinary` + // instead of `putString` to avoid unnecessary conversion. + case StringType => (filter, row) => filter.putBinary(row.getUTF8String(0).getBytes) + case ByteType => (filter, row) => filter.putLong(row.getByte(0)) + case ShortType => (filter, row) => filter.putLong(row.getShort(0)) + case IntegerType => (filter, row) => filter.putLong(row.getInt(0)) + case LongType => (filter, row) => filter.putLong(row.getLong(0)) + case _ => + throw new IllegalArgumentException( + s"Bloom filter only supports string type and integral types, " + + s"and does not support type $colType." + ) } - singleCol.queryExecution.toRdd.aggregate(zero)(seqOp, _ mergeInPlace _) + singleCol.queryExecution.toRdd.aggregate(zero)( + (filter: BloomFilter, row: InternalRow) => { + updater(filter, row) + filter + }, + (filter1, filter2) => filter1.mergeInPlace(filter2) + ) } } From 676803963fcc08aa988aa6f14be3751314e006ca Mon Sep 17 00:00:00 2001 From: Tejas Patil Date: Thu, 28 Jan 2016 13:45:28 -0800 Subject: [PATCH 049/131] [SPARK-12926][SQL] SQLContext to display warning message when non-sql configs are being set Users unknowingly try to set core Spark configs in SQLContext but later realise that it didn't work. eg. sqlContext.sql("SET spark.shuffle.memoryFraction=0.4"). This PR adds a warning message when such operations are done. Author: Tejas Patil Closes #10849 from tejasapatil/SPARK-12926. 
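To make the intended behaviour concrete, here is a minimal usage sketch (assuming a `sqlContext` is in scope, e.g. in `spark-shell`). Only keys that start with `spark.` but not with `spark.sql.` trigger the new warning, and the value is still stored in either case.

```
// Expected to log a warning of the form
// "Attempt to set non-Spark SQL config in SQLConf: key = ..., value = ..."
// while still recording the setting.
sqlContext.sql("SET spark.shuffle.memoryFraction=0.4")

// A spark.sql.* key goes through the same setConfWithCheck path without a warning.
sqlContext.sql("SET spark.sql.shuffle.partitions=400")
```
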
--- .../main/scala/org/apache/spark/sql/SQLConf.scala | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index c9ba6700998c..eb9da0bd4fd4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.parquet.hadoop.ParquetOutputCommitter +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.parser.ParserConf import org.apache.spark.util.Utils @@ -519,7 +520,7 @@ private[spark] object SQLConf { * * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). */ -private[sql] class SQLConf extends Serializable with CatalystConf with ParserConf { +private[sql] class SQLConf extends Serializable with CatalystConf with ParserConf with Logging { import SQLConf._ /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @@ -628,7 +629,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon // Only verify configs in the SQLConf object entry.valueConverter(value) } - settings.put(key, value) + setConfWithCheck(key, value) } /** Set the given Spark SQL configuration property. */ @@ -636,7 +637,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon require(entry != null, "entry cannot be null") require(value != null, s"value cannot be null for key: ${entry.key}") require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered") - settings.put(entry.key, entry.stringConverter(value)) + setConfWithCheck(entry.key, entry.stringConverter(value)) } /** Return the value of Spark SQL configuration property for the given key. */ @@ -699,6 +700,13 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon }.toSeq } + private def setConfWithCheck(key: String, value: String): Unit = { + if (key.startsWith("spark.") && !key.startsWith("spark.sql.")) { + logWarning(s"Attempt to set non-Spark SQL config in SQLConf: key = $key, value = $value") + } + settings.put(key, value) + } + private[spark] def unsetConf(key: String): Unit = { settings.remove(key) } From cc18a7199240bf3b03410c1ba6704fe7ce6ae38e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 28 Jan 2016 13:51:55 -0800 Subject: [PATCH 050/131] [SPARK-13031] [SQL] cleanup codegen and improve test coverage 1. enable whole stage codegen during tests even there is only one operator supports that. 2. split doProduce() into two APIs: upstream() and doProduce() 3. generate prefix for fresh names of each operator 4. pass UnsafeRow to parent directly (avoid getters and create UnsafeRow again) 5. fix bugs and tests. Author: Davies Liu Closes #10944 from davies/gen_refactor. 
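To illustrate items 2 and 3 above, here is a condensed sketch of what a unary operator implements under the split API. `MyFilter` is a made-up name and the bodies are abbreviated; the real `Filter`/`Project` changes appear later in this patch, and a snippet like this would only compile alongside them inside `org.apache.spark.sql.execution`.

```
case class MyFilter(condition: Expression, child: SparkPlan)
  extends UnaryNode with CodegenSupport {

  // upstream() only answers "which RDD feeds this stage"; it no longer carries
  // any generated code (that used to be half of the old doProduce() result).
  override def upstream(): RDD[InternalRow] =
    child.asInstanceOf[CodegenSupport].upstream()

  // doProduce() now returns only Java source. produce()/consume() thread it
  // through the parent, and produce() sets ctx.freshNamePrefix = nodeName so
  // every fresh variable name is prefixed per operator.
  protected override def doProduce(ctx: CodegenContext): String =
    child.asInstanceOf[CodegenSupport].produce(ctx, this)

  // Called with the child's column variables; emits code for this operator and
  // hands the surviving row to the parent via consume().
  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
    val bound = BindReferences.bindReference(condition, child.output)
    ctx.currentVars = input
    val eval = bound.gen(ctx)
    s"""
       | ${eval.code}
       | if (!${eval.isNull} && ${eval.value}) {
       |   ${consume(ctx, input)}
       | }
     """.stripMargin
  }
}
```
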
--- .../expressions/codegen/CodeGenerator.scala | 13 +- .../codegen/GenerateMutableProjection.scala | 2 +- .../sql/execution/WholeStageCodegen.scala | 188 ++++++++++++------ .../aggregate/TungstenAggregate.scala | 88 +++++--- .../spark/sql/execution/basicOperators.scala | 96 +++++---- .../org/apache/spark/sql/SQLQuerySuite.scala | 103 +++++----- .../execution/metric/SQLMetricsSuite.scala | 34 ++-- .../apache/spark/sql/test/SQLTestUtils.scala | 2 +- .../sql/util/DataFrameCallbackSuite.scala | 10 +- 9 files changed, 334 insertions(+), 202 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2747c315ad37..e6704cf8bb1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -144,14 +144,23 @@ class CodegenContext { private val curId = new java.util.concurrent.atomic.AtomicInteger() + /** + * A prefix used to generate fresh name. + */ + var freshNamePrefix = "" + /** * Returns a term name that is unique within this instance of a `CodeGenerator`. * * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` * function.) */ - def freshName(prefix: String): String = { - s"$prefix${curId.getAndIncrement}" + def freshName(name: String): String = { + if (freshNamePrefix == "") { + s"$name${curId.getAndIncrement}" + } else { + s"${freshNamePrefix}_$name${curId.getAndIncrement}" + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index d9fe76133c6e..ec31db19b94b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -93,7 +93,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu // Can't call setNullAt on DecimalType, because we need to keep the offset s""" if (this.isNull_$i) { - ${ctx.setColumn("mutableRow", e.dataType, i, null)}; + ${ctx.setColumn("mutableRow", e.dataType, i, "null")}; } else { ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 57f4945de980..ef81ba60f049 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -22,9 +22,11 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, LeafExpression} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.util.Utils /** * An interface for those physical operators that 
support codegen. @@ -42,10 +44,16 @@ trait CodegenSupport extends SparkPlan { private var parent: CodegenSupport = null /** - * Returns an input RDD of InternalRow and Java source code to process them. + * Returns the RDD of InternalRow which generates the input rows. */ - def produce(ctx: CodegenContext, parent: CodegenSupport): (RDD[InternalRow], String) = { + def upstream(): RDD[InternalRow] + + /** + * Returns Java source code to process the rows from upstream. + */ + def produce(ctx: CodegenContext, parent: CodegenSupport): String = { this.parent = parent + ctx.freshNamePrefix = nodeName doProduce(ctx) } @@ -66,16 +74,41 @@ trait CodegenSupport extends SparkPlan { * # call consume(), wich will call parent.doConsume() * } */ - protected def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) + protected def doProduce(ctx: CodegenContext): String /** - * Consume the columns generated from current SparkPlan, call it's parent or create an iterator. + * Consume the columns generated from current SparkPlan, call it's parent. */ - protected def consume(ctx: CodegenContext, columns: Seq[ExprCode]): String = { - assert(columns.length == output.length) - parent.doConsume(ctx, this, columns) + def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { + if (input != null) { + assert(input.length == output.length) + } + parent.consumeChild(ctx, this, input, row) } + /** + * Consume the columns generated from it's child, call doConsume() or emit the rows. + */ + def consumeChild( + ctx: CodegenContext, + child: SparkPlan, + input: Seq[ExprCode], + row: String = null): String = { + ctx.freshNamePrefix = nodeName + if (row != null) { + ctx.currentVars = null + ctx.INPUT_ROW = row + val evals = child.output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable).gen(ctx) + } + s""" + | ${evals.map(_.code).mkString("\n")} + | ${doConsume(ctx, evals)} + """.stripMargin + } else { + doConsume(ctx, input) + } + } /** * Generate the Java source code to process the rows from child SparkPlan. 
@@ -89,7 +122,9 @@ trait CodegenSupport extends SparkPlan { * # call consume(), which will call parent.doConsume() * } */ - def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String + protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + throw new UnsupportedOperationException + } } @@ -102,31 +137,36 @@ trait CodegenSupport extends SparkPlan { case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def doPrepare(): Unit = { + child.prepare() + } - override def supportCodegen: Boolean = true + override def doExecute(): RDD[InternalRow] = { + child.execute() + } - override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def supportCodegen: Boolean = false + + override def upstream(): RDD[InternalRow] = { + child.execute() + } + + override def doProduce(ctx: CodegenContext): String = { val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true)) val row = ctx.freshName("row") ctx.INPUT_ROW = row ctx.currentVars = null val columns = exprs.map(_.gen(ctx)) - val code = s""" - | while (input.hasNext()) { + s""" + | while (input.hasNext()) { | InternalRow $row = (InternalRow) input.next(); | ${columns.map(_.code).mkString("\n")} | ${consume(ctx, columns)} | } """.stripMargin - (child.execute(), code) - } - - def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { - throw new UnsupportedOperationException - } - - override def doExecute(): RDD[InternalRow] = { - throw new UnsupportedOperationException } override def simpleString: String = "INPUT" @@ -143,16 +183,20 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { * * -> execute() * | - * doExecute() --------> produce() + * doExecute() ---------> upstream() -------> upstream() ------> execute() + * | + * -----------------> produce() * | * doProduce() -------> produce() * | - * doProduce() ---> execute() + * doProduce() * | * consume() - * doConsume() ------------| + * consumeChild() <-----------| * | - * doConsume() <----- consume() + * doConsume() + * | + * consumeChild() <----- consume() * * SparkPlan A should override doProduce() and doConsume(). 
* @@ -162,37 +206,48 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) extends SparkPlan with CodegenSupport { + override def supportCodegen: Boolean = false + override def output: Seq[Attribute] = plan.output + override def outputPartitioning: Partitioning = plan.outputPartitioning + override def outputOrdering: Seq[SortOrder] = plan.outputOrdering + + override def doPrepare(): Unit = { + plan.prepare() + } override def doExecute(): RDD[InternalRow] = { val ctx = new CodegenContext - val (rdd, code) = plan.produce(ctx, this) + val code = plan.produce(ctx, this) val references = ctx.references.toArray val source = s""" public Object generate(Object[] references) { - return new GeneratedIterator(references); + return new GeneratedIterator(references); } class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { - private Object[] references; - ${ctx.declareMutableStates()} + private Object[] references; + ${ctx.declareMutableStates()} + ${ctx.declareAddedFunctions()} - public GeneratedIterator(Object[] references) { + public GeneratedIterator(Object[] references) { this.references = references; ${ctx.initMutableStates()} - } + } - protected void processNext() { + protected void processNext() throws java.io.IOException { $code - } + } } - """ + """ + // try to compile, helpful for debug // println(s"${CodeFormatter.format(source)}") CodeGenerator.compile(source) - rdd.mapPartitions { iter => + plan.upstream().mapPartitions { iter => + val clazz = CodeGenerator.compile(source) val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] buffer.setInput(iter) @@ -203,29 +258,47 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) } } - override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { throw new UnsupportedOperationException } - override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { - if (input.nonEmpty) { - val colExprs = output.zipWithIndex.map { case (attr, i) => - BoundReference(i, attr.dataType, attr.nullable) - } - // generate the code to create a UnsafeRow - ctx.currentVars = input - val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) - s""" - | ${code.code.trim} - | currentRow = ${code.value}; - | return; - """.stripMargin - } else { - // There is no columns + override def doProduce(ctx: CodegenContext): String = { + throw new UnsupportedOperationException + } + + override def consumeChild( + ctx: CodegenContext, + child: SparkPlan, + input: Seq[ExprCode], + row: String = null): String = { + + if (row != null) { + // There is an UnsafeRow already s""" - | currentRow = unsafeRow; + | currentRow = $row; | return; """.stripMargin + } else { + assert(input != null) + if (input.nonEmpty) { + val colExprs = output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable) + } + // generate the code to create a UnsafeRow + ctx.currentVars = input + val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) + s""" + | ${code.code.trim} + | currentRow = ${code.value}; + | return; + """.stripMargin + } else { + // There is no columns + s""" + | currentRow = unsafeRow; + | return; + """.stripMargin + } } } @@ -246,7 +319,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) builder.append(simpleString) builder.append("\n") - 
plan.generateTreeString(depth + 1, lastChildren :+children.isEmpty :+ true, builder) + plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder) if (children.nonEmpty) { children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder)) children.last.generateTreeString(depth + 1, lastChildren :+ true, builder) @@ -286,13 +359,14 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru case plan: CodegenSupport if supportCodegen(plan) && // Whole stage codegen is only useful when there are at least two levels of operators that // support it (save at least one projection/iterator). - plan.children.exists(supportCodegen) => + (Utils.isTesting || plan.children.exists(supportCodegen)) => var inputs = ArrayBuffer[SparkPlan]() val combined = plan.transform { case p if !supportCodegen(p) => - inputs += p - InputAdapter(p) + val input = apply(p) // collapse them recursively + inputs += input + InputAdapter(input) }.asInstanceOf[CodegenSupport] WholeStageCodegen(combined, inputs) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 23e54f344d25..cbd2634b8900 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -117,9 +117,7 @@ case class TungstenAggregate( override def supportCodegen: Boolean = { groupingExpressions.isEmpty && // ImperativeAggregate is not supported right now - !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) && - // final aggregation only have one row, do not need to codegen - !aggregateExpressions.exists(e => e.mode == Final || e.mode == Complete) + !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) } // The variables used as aggregation buffer @@ -127,7 +125,11 @@ case class TungstenAggregate( private val modes = aggregateExpressions.map(_.mode).distinct - protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { + child.asInstanceOf[CodegenSupport].upstream() + } + + protected override def doProduce(ctx: CodegenContext): String = { val initAgg = ctx.freshName("initAgg") ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") @@ -137,50 +139,80 @@ case class TungstenAggregate( bufVars = initExpr.map { e => val isNull = ctx.freshName("bufIsNull") val value = ctx.freshName("bufValue") + ctx.addMutableState("boolean", isNull, "") + ctx.addMutableState(ctx.javaType(e.dataType), value, "") // The initial expression should not access any column val ev = e.gen(ctx) val initVars = s""" - | boolean $isNull = ${ev.isNull}; - | ${ctx.javaType(e.dataType)} $value = ${ev.value}; + | $isNull = ${ev.isNull}; + | $value = ${ev.value}; """.stripMargin ExprCode(ev.code + initVars, isNull, value) } - val (rdd, childSource) = child.asInstanceOf[CodegenSupport].produce(ctx, this) - val source = + // generate variables for output + val (resultVars, genResult) = if (modes.contains(Final) | modes.contains(Complete)) { + // evaluate aggregate results + ctx.currentVars = bufVars + val bufferAttrs = functions.flatMap(_.aggBufferAttributes) + val aggResults = functions.map(_.evaluateExpression).map { e => + BindReferences.bindReference(e, bufferAttrs).gen(ctx) + } + // evaluate result expressions + 
ctx.currentVars = aggResults + val resultVars = resultExpressions.map { e => + BindReferences.bindReference(e, aggregateAttributes).gen(ctx) + } + (resultVars, s""" + | ${aggResults.map(_.code).mkString("\n")} + | ${resultVars.map(_.code).mkString("\n")} + """.stripMargin) + } else { + // output the aggregate buffer directly + (bufVars, "") + } + + val doAgg = ctx.freshName("doAgg") + ctx.addNewFunction(doAgg, s""" - | if (!$initAgg) { - | $initAgg = true; - | + | private void $doAgg() { | // initialize aggregation buffer | ${bufVars.map(_.code).mkString("\n")} | - | $childSource - | - | // output the result - | ${consume(ctx, bufVars)} + | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} | } - """.stripMargin + """.stripMargin) - (rdd, source) + s""" + | if (!$initAgg) { + | $initAgg = true; + | $doAgg(); + | + | // output the result + | $genResult + | + | ${consume(ctx, resultVars)} + | } + """.stripMargin } - override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) - // the mode could be only Partial or PartialMerge - val updateExpr = if (modes.contains(Partial)) { - functions.flatMap(_.updateExpressions) - } else { - functions.flatMap(_.mergeExpressions) + val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output + val updateExpr = aggregateExpressions.flatMap { e => + e.mode match { + case Partial | Complete => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions + case PartialMerge | Final => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions + } } - val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output - val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr)) ctx.currentVars = bufVars ++ input // TODO: support subexpression elimination - val codes = boundExpr.zipWithIndex.map { case (e, i) => - val ev = e.gen(ctx) + val updates = updateExpr.zipWithIndex.map { case (e, i) => + val ev = BindReferences.bindReference[Expression](e, inputAttrs).gen(ctx) s""" | ${ev.code} | ${bufVars(i).isNull} = ${ev.isNull}; @@ -190,7 +222,7 @@ case class TungstenAggregate( s""" | // do aggregate and update aggregation buffer - | ${codes.mkString("")} + | ${updates.mkString("")} """.stripMargin } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 6deb72adad5e..e7a73d5fbb4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -37,11 +37,15 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) override def output: Seq[Attribute] = projectList.map(_.toAttribute) - protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { + child.asInstanceOf[CodegenSupport].upstream() + } + + protected override def doProduce(ctx: CodegenContext): String = { child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { val exprs = projectList.map(x => 
ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output))) ctx.currentVars = input @@ -76,11 +80,15 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { + child.asInstanceOf[CodegenSupport].upstream() + } + + protected override def doProduce(ctx: CodegenContext): String = { child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { val expr = ExpressionCanonicalizer.execute( BindReferences.bindReference(condition, child.output)) ctx.currentVars = input @@ -153,17 +161,21 @@ case class Range( output: Seq[Attribute]) extends LeafNode with CodegenSupport { - protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { - val initTerm = ctx.freshName("range_initRange") + override def upstream(): RDD[InternalRow] = { + sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i)) + } + + protected override def doProduce(ctx: CodegenContext): String = { + val initTerm = ctx.freshName("initRange") ctx.addMutableState("boolean", initTerm, s"$initTerm = false;") - val partitionEnd = ctx.freshName("range_partitionEnd") + val partitionEnd = ctx.freshName("partitionEnd") ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;") - val number = ctx.freshName("range_number") + val number = ctx.freshName("number") ctx.addMutableState("long", number, s"$number = 0L;") - val overflow = ctx.freshName("range_overflow") + val overflow = ctx.freshName("overflow") ctx.addMutableState("boolean", overflow, s"$overflow = false;") - val value = ctx.freshName("range_value") + val value = ctx.freshName("value") val ev = ExprCode("", "false", value) val BigInt = classOf[java.math.BigInteger].getName val checkEnd = if (step > 0) { @@ -172,38 +184,42 @@ case class Range( s"$number > $partitionEnd" } - val rdd = sqlContext.sparkContext.parallelize(0 until numSlices, numSlices) - .map(i => InternalRow(i)) + ctx.addNewFunction("initRange", + s""" + | private void initRange(int idx) { + | $BigInt index = $BigInt.valueOf(idx); + | $BigInt numSlice = $BigInt.valueOf(${numSlices}L); + | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L); + | $BigInt step = $BigInt.valueOf(${step}L); + | $BigInt start = $BigInt.valueOf(${start}L); + | + | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); + | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { + | $number = Long.MAX_VALUE; + | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { + | $number = Long.MIN_VALUE; + | } else { + | $number = st.longValue(); + | } + | + | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice) + | .multiply(step).add(start); + | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { + | $partitionEnd = Long.MAX_VALUE; + | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { + | $partitionEnd = Long.MIN_VALUE; + | } else { + | $partitionEnd = end.longValue(); + | } + | } + """.stripMargin) - val code = s""" + s""" | // initialize Range | if (!$initTerm) { | $initTerm = true; | 
if (input.hasNext()) { - | $BigInt index = $BigInt.valueOf(((InternalRow) input.next()).getInt(0)); - | $BigInt numSlice = $BigInt.valueOf(${numSlices}L); - | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L); - | $BigInt step = $BigInt.valueOf(${step}L); - | $BigInt start = $BigInt.valueOf(${start}L); - | - | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); - | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { - | $number = Long.MAX_VALUE; - | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { - | $number = Long.MIN_VALUE; - | } else { - | $number = st.longValue(); - | } - | - | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice) - | .multiply(step).add(start); - | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { - | $partitionEnd = Long.MAX_VALUE; - | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { - | $partitionEnd = Long.MIN_VALUE; - | } else { - | $partitionEnd = end.longValue(); - | } + | initRange(((InternalRow) input.next()).getInt(0)); | } else { | return; | } @@ -218,12 +234,6 @@ case class Range( | ${consume(ctx, Seq(ev))} | } """.stripMargin - - (rdd, code) - } - - def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { - throw new UnsupportedOperationException } protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 989cb2942918..51a50c1fa30e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1939,58 +1939,61 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("Common subexpression elimination") { - // select from a table to prevent constant folding. - val df = sql("SELECT a, b from testData2 limit 1") - checkAnswer(df, Row(1, 1)) - - checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) - checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3)) - - // This does not work because the expressions get grouped like (a + a) + 1 - checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3)) - checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) - - // Identity udf that tracks the number of times it is called. - val countAcc = sparkContext.accumulator(0, "CallCount") - sqlContext.udf.register("testUdf", (x: Int) => { - countAcc.++=(1) - x - }) + // TODO: support subexpression elimination in whole stage codegen + withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + // select from a table to prevent constant folding. + val df = sql("SELECT a, b from testData2 limit 1") + checkAnswer(df, Row(1, 1)) + + checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) + checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3)) + + // This does not work because the expressions get grouped like (a + a) + 1 + checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3)) + checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) + + // Identity udf that tracks the number of times it is called. + val countAcc = sparkContext.accumulator(0, "CallCount") + sqlContext.udf.register("testUdf", (x: Int) => { + countAcc.++=(1) + x + }) + + // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value + // is correct. 
+ def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { + countAcc.setValue(0) + checkAnswer(df, expectedResult) + assert(countAcc.value == expectedCount) + } - // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value - // is correct. - def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { - countAcc.setValue(0) - checkAnswer(df, expectedResult) - assert(countAcc.value == expectedCount) + verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1) + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2) + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1) + + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) + + val testUdf = functions.udf((x: Int) => { + countAcc.++=(1) + x + }) + verifyCallCount( + df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) + + // Would be nice if semantic equals for `+` understood commutative + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) + + // Try disabling it via configuration. + sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) + sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) } - - verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1) - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) - verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1) - verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2) - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1) - - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) - - val testUdf = functions.udf((x: Int) => { - countAcc.++=(1) - x - }) - verifyCallCount( - df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) - - // Would be nice if semantic equals for `+` understood commutative - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) - - // Try disabling it via configuration. 
- sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) - sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) } test("SPARK-10707: nullability should be correctly propagated through set operations (1)") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index cbae19ebd269..82f6811503c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -335,22 +335,24 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { test("save metrics") { withTempPath { file => - val previousExecutionIds = sqlContext.listener.executionIdToData.keySet - // Assume the execution plan is - // PhysicalRDD(nodeId = 0) - person.select('name).write.format("json").save(file.getAbsolutePath) - sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) - assert(executionIds.size === 1) - val executionId = executionIds.head - val jobs = sqlContext.listener.getExecution(executionId).get.jobs - // Use "<=" because there is a race condition that we may miss some jobs - // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. - assert(jobs.size <= 1) - val metricValues = sqlContext.listener.getExecutionMetrics(executionId) - // Because "save" will create a new DataFrame internally, we cannot get the real metric id. - // However, we still can check the value. - assert(metricValues.values.toSeq === Seq("2")) + withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + val previousExecutionIds = sqlContext.listener.executionIdToData.keySet + // Assume the execution plan is + // PhysicalRDD(nodeId = 0) + person.select('name).write.format("json").save(file.getAbsolutePath) + sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size === 1) + val executionId = executionIds.head + val jobs = sqlContext.listener.getExecution(executionId).get.jobs + // Use "<=" because there is a race condition that we may miss some jobs + // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. + assert(jobs.size <= 1) + val metricValues = sqlContext.listener.getExecutionMetrics(executionId) + // Because "save" will create a new DataFrame internally, we cannot get the real metric id. + // However, we still can check the value. 
+ assert(metricValues.values.toSeq === Seq("2")) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index d48143762cac..7d6bff8295d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -199,7 +199,7 @@ private[sql] trait SQLTestUtils val schema = df.schema val childRDD = df .queryExecution - .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] + .sparkPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] .child .execute() .map(row => Row.fromSeq(row.copy().toSeq(schema))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index 9a24a2487a25..a3e5243b68ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -97,10 +97,12 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { } sqlContext.listenerManager.register(listener) - val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count() - df.collect() - df.collect() - Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect() + withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count() + df.collect() + df.collect() + Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect() + } assert(metrics.length == 3) assert(metrics(0) == 1) From df78a934a07a4ce5af43243be9ba5fe60b91eee6 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 28 Jan 2016 14:29:47 -0800 Subject: [PATCH 051/131] [SPARK-9835][ML] Implement IterativelyReweightedLeastSquares solver Implement ```IterativelyReweightedLeastSquares``` solver for GLM. I consider it as a solver rather than estimator, it only used internal so I keep it ```private[ml]```. There are two limitations in the current implementation compared with R: * It can not support ```Tuple``` as response for ```Binomial``` family, such as the following code: ``` glm( cbind(using, notUsing) ~ age + education + wantsMore , family = binomial) ``` * It does not support ```offset```. Because I considered that ```RFormula``` did not support ```Tuple``` as label and ```offset``` keyword, so I simplified the implementation. But to add support for these two functions is not very hard, I can do it in follow-up PR if it is necessary. Meanwhile, we can also add R-like statistic summary for IRLS. The implementation refers R, [statsmodels](https://github.com/statsmodels/statsmodels) and [sparkGLM](https://github.com/AlteryxLabs/sparkGLM). Please focus on the main structure and overpass minor issues/docs that I will update later. Any comments and opinions will be appreciated. cc mengxr jkbradley Author: Yanbo Liang Closes #10639 from yanboliang/spark-9835. 
--- .../IterativelyReweightedLeastSquares.scala | 108 ++++++++++ .../spark/ml/optim/WeightedLeastSquares.scala | 7 +- ...erativelyReweightedLeastSquaresSuite.scala | 200 ++++++++++++++++++ 3 files changed, 314 insertions(+), 1 deletion(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala new file mode 100644 index 000000000000..6aa44e6ba723 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.optim + +import org.apache.spark.Logging +import org.apache.spark.ml.feature.Instance +import org.apache.spark.mllib.linalg._ +import org.apache.spark.rdd.RDD + +/** + * Model fitted by [[IterativelyReweightedLeastSquares]]. + * @param coefficients model coefficients + * @param intercept model intercept + */ +private[ml] class IterativelyReweightedLeastSquaresModel( + val coefficients: DenseVector, + val intercept: Double) extends Serializable + +/** + * Implements the method of iteratively reweighted least squares (IRLS) which is used to solve + * certain optimization problems by an iterative method. In each step of the iterations, it + * involves solving a weighted lease squares (WLS) problem by [[WeightedLeastSquares]]. + * It can be used to find maximum likelihood estimates of a generalized linear model (GLM), + * find M-estimator in robust regression and other optimization problems. + * + * @param initialModel the initial guess model. + * @param reweightFunc the reweight function which is used to update offsets and weights + * at each iteration. + * @param fitIntercept whether to fit intercept. + * @param regParam L2 regularization parameter used by WLS. + * @param maxIter maximum number of iterations. + * @param tol the convergence tolerance. + * + * @see [[http://www.jstor.org/stable/2345503 P. J. Green, Iteratively Reweighted Least Squares + * for Maximum Likelihood Estimation, and some Robust and Resistant Alternatives, + * Journal of the Royal Statistical Society. 
Series B, 1984.]] + */ +private[ml] class IterativelyReweightedLeastSquares( + val initialModel: WeightedLeastSquaresModel, + val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double), + val fitIntercept: Boolean, + val regParam: Double, + val maxIter: Int, + val tol: Double) extends Logging with Serializable { + + def fit(instances: RDD[Instance]): IterativelyReweightedLeastSquaresModel = { + + var converged = false + var iter = 0 + + var model: WeightedLeastSquaresModel = initialModel + var oldModel: WeightedLeastSquaresModel = null + + while (iter < maxIter && !converged) { + + oldModel = model + + // Update offsets and weights using reweightFunc + val newInstances = instances.map { instance => + val (newOffset, newWeight) = reweightFunc(instance, oldModel) + Instance(newOffset, newWeight, instance.features) + } + + // Estimate new model + model = new WeightedLeastSquares(fitIntercept, regParam, standardizeFeatures = false, + standardizeLabel = false).fit(newInstances) + + // Check convergence + val oldCoefficients = oldModel.coefficients + val coefficients = model.coefficients + BLAS.axpy(-1.0, coefficients, oldCoefficients) + val maxTolOfCoefficients = oldCoefficients.toArray.reduce { (x, y) => + math.max(math.abs(x), math.abs(y)) + } + val maxTol = math.max(maxTolOfCoefficients, math.abs(oldModel.intercept - model.intercept)) + + if (maxTol < tol) { + converged = true + logInfo(s"IRLS converged in $iter iterations.") + } + + logInfo(s"Iteration $iter : relative tolerance = $maxTol") + iter = iter + 1 + + if (iter == maxIter) { + logInfo(s"IRLS reached the max number of iterations: $maxIter.") + } + + } + + new IterativelyReweightedLeastSquaresModel(model.coefficients, model.intercept) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 797870eb8ce8..61b364213181 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -31,7 +31,12 @@ import org.apache.spark.rdd.RDD private[ml] class WeightedLeastSquaresModel( val coefficients: DenseVector, val intercept: Double, - val diagInvAtWA: DenseVector) extends Serializable + val diagInvAtWA: DenseVector) extends Serializable { + + def predict(features: Vector): Double = { + BLAS.dot(coefficients, features) + intercept + } +} /** * Weighted least squares solver via normal equation. diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala new file mode 100644 index 000000000000..604021220a13 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.optim + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.Instance +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.rdd.RDD + +class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext { + + private var instances1: RDD[Instance] = _ + private var instances2: RDD[Instance] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 2, 1, 3), 4, 2) + b <- c(1, 0, 1, 0) + w <- c(1, 2, 3, 4) + */ + instances1 = sc.parallelize(Seq( + Instance(1.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(0.0, 2.0, Vectors.dense(1.0, 2.0)), + Instance(1.0, 3.0, Vectors.dense(2.0, 1.0)), + Instance(0.0, 4.0, Vectors.dense(3.0, 3.0)) + ), 2) + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b <- c(2, 8, 3, 9) + w <- c(1, 2, 3, 4) + */ + instances2 = sc.parallelize(Seq( + Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(9.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2) + } + + test("IRLS against GLM with Binomial errors") { + /* + R code: + + df <- as.data.frame(cbind(A, b)) + for (formula in c(b ~ . -1, b ~ .)) { + model <- glm(formula, family="binomial", data=df, weights=w) + print(as.vector(coef(model))) + } + + [1] -0.30216651 -0.04452045 + [1] 3.5651651 -1.2334085 -0.7348971 + */ + val expected = Seq( + Vectors.dense(0.0, -0.30216651, -0.04452045), + Vectors.dense(3.5651651, -1.2334085, -0.7348971)) + + import IterativelyReweightedLeastSquaresSuite._ + + var idx = 0 + for (fitIntercept <- Seq(false, true)) { + val newInstances = instances1.map { instance => + val mu = (instance.label + 0.5) / 2.0 + val eta = math.log(mu / (1.0 - mu)) + Instance(eta, instance.weight, instance.features) + } + val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, + standardizeFeatures = false, standardizeLabel = false).fit(newInstances) + val irls = new IterativelyReweightedLeastSquares(initial, BinomialReweightFunc, + fitIntercept, regParam = 0.0, maxIter = 25, tol = 1e-8).fit(instances1) + val actual = Vectors.dense(irls.intercept, irls.coefficients(0), irls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + idx += 1 + } + } + + test("IRLS against GLM with Poisson errors") { + /* + R code: + + df <- as.data.frame(cbind(A, b)) + for (formula in c(b ~ . 
-1, b ~ .)) { + model <- glm(formula, family="poisson", data=df, weights=w) + print(as.vector(coef(model))) + } + + [1] -0.09607792 0.18375613 + [1] 6.299947 3.324107 -1.081766 + */ + val expected = Seq( + Vectors.dense(0.0, -0.09607792, 0.18375613), + Vectors.dense(6.299947, 3.324107, -1.081766)) + + import IterativelyReweightedLeastSquaresSuite._ + + var idx = 0 + for (fitIntercept <- Seq(false, true)) { + val yMean = instances2.map(_.label).mean + val newInstances = instances2.map { instance => + val mu = (instance.label + yMean) / 2.0 + val eta = math.log(mu) + Instance(eta, instance.weight, instance.features) + } + val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, + standardizeFeatures = false, standardizeLabel = false).fit(newInstances) + val irls = new IterativelyReweightedLeastSquares(initial, PoissonReweightFunc, + fitIntercept, regParam = 0.0, maxIter = 25, tol = 1e-8).fit(instances2) + val actual = Vectors.dense(irls.intercept, irls.coefficients(0), irls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + idx += 1 + } + } + + test("IRLS against L1Regression") { + /* + R code: + + library(quantreg) + + df <- as.data.frame(cbind(A, b)) + for (formula in c(b ~ . -1, b ~ .)) { + model <- rq(formula, data=df, weights=w) + print(as.vector(coef(model))) + } + + [1] 1.266667 0.400000 + [1] 29.5 17.0 -5.5 + */ + val expected = Seq( + Vectors.dense(0.0, 1.266667, 0.400000), + Vectors.dense(29.5, 17.0, -5.5)) + + import IterativelyReweightedLeastSquaresSuite._ + + var idx = 0 + for (fitIntercept <- Seq(false, true)) { + val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, + standardizeFeatures = false, standardizeLabel = false).fit(instances2) + val irls = new IterativelyReweightedLeastSquares(initial, L1RegressionReweightFunc, + fitIntercept, regParam = 0.0, maxIter = 200, tol = 1e-7).fit(instances2) + val actual = Vectors.dense(irls.intercept, irls.coefficients(0), irls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + idx += 1 + } + } +} + +object IterativelyReweightedLeastSquaresSuite { + + def BinomialReweightFunc( + instance: Instance, + model: WeightedLeastSquaresModel): (Double, Double) = { + val eta = model.predict(instance.features) + val mu = 1.0 / (1.0 + math.exp(-1.0 * eta)) + val z = eta + (instance.label - mu) / (mu * (1.0 - mu)) + val w = mu * (1 - mu) * instance.weight + (z, w) + } + + def PoissonReweightFunc( + instance: Instance, + model: WeightedLeastSquaresModel): (Double, Double) = { + val eta = model.predict(instance.features) + val mu = math.exp(eta) + val z = eta + (instance.label - mu) / mu + val w = mu * instance.weight + (z, w) + } + + def L1RegressionReweightFunc( + instance: Instance, + model: WeightedLeastSquaresModel): (Double, Double) = { + val eta = model.predict(instance.features) + val e = math.max(math.abs(eta - instance.label), 1e-7) + val w = 1 / e + val y = instance.label + (y, w) + } +} From abae889f08eb412cb897e4e63614ec2c93885ffd Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Thu, 28 Jan 2016 15:20:16 -0800 Subject: [PATCH 052/131] [SPARK-12401][SQL] Add integration tests for postgres enum types We can handle posgresql-specific enum types as strings in jdbc. So, we should just add tests and close the corresponding JIRA ticket. Author: Takeshi YAMAMURO Closes #10596 from maropu/AddTestsInIntegration. 
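As a small usage illustration (a sketch only; the JDBC URL below is hypothetical, while the table `bar`, column `c13` and value `'d1'` are the ones created in the test): an enum column read through the JDBC data source needs no special handling and simply surfaces as a string column.

```
import java.util.Properties

// Read the integration-test table over JDBC; the enum column is mapped to a plain string.
val df = sqlContext.read.jdbc("jdbc:postgresql://localhost:5432/foo", "bar", new Properties)

df.schema("c13").dataType            // expected: StringType
df.select("c13").first.getString(0)  // expected: "d1"
```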
--- .../spark/sql/jdbc/PostgresIntegrationSuite.scala | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 7d011be37067..72bda8fe1ef1 100644 --- a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -21,7 +21,7 @@ import java.sql.Connection import java.util.Properties import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.expressions.{If, Literal} +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.tags.DockerTest @DockerTest @@ -39,12 +39,13 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { override def dataPreparation(conn: Connection): Unit = { conn.prepareStatement("CREATE DATABASE foo").executeUpdate() conn.setCatalog("foo") + conn.prepareStatement("CREATE TYPE enum_type AS ENUM ('d1', 'd2')").executeUpdate() conn.prepareStatement("CREATE TABLE bar (c0 text, c1 integer, c2 double precision, c3 bigint, " + "c4 bit(1), c5 bit(10), c6 bytea, c7 boolean, c8 inet, c9 cidr, " - + "c10 integer[], c11 text[], c12 real[])").executeUpdate() + + "c10 integer[], c11 text[], c12 real[], c13 enum_type)").executeUpdate() conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', " + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16', " - + """'{1, 2}', '{"a", null, "b"}', '{0.11, 0.22}')""").executeUpdate() + + """'{1, 2}', '{"a", null, "b"}', '{0.11, 0.22}', 'd1')""").executeUpdate() } test("Type mapping for various types") { @@ -52,7 +53,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { val rows = df.collect() assert(rows.length == 1) val types = rows(0).toSeq.map(x => x.getClass) - assert(types.length == 13) + assert(types.length == 14) assert(classOf[String].isAssignableFrom(types(0))) assert(classOf[java.lang.Integer].isAssignableFrom(types(1))) assert(classOf[java.lang.Double].isAssignableFrom(types(2))) @@ -66,22 +67,24 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { assert(classOf[Seq[Int]].isAssignableFrom(types(10))) assert(classOf[Seq[String]].isAssignableFrom(types(11))) assert(classOf[Seq[Double]].isAssignableFrom(types(12))) + assert(classOf[String].isAssignableFrom(types(13))) assert(rows(0).getString(0).equals("hello")) assert(rows(0).getInt(1) == 42) assert(rows(0).getDouble(2) == 1.25) assert(rows(0).getLong(3) == 123456789012345L) - assert(rows(0).getBoolean(4) == false) + assert(!rows(0).getBoolean(4)) // BIT(10)'s come back as ASCII strings of ten ASCII 0's and 1's... 
assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](5), Array[Byte](49, 48, 48, 48, 49, 48, 48, 49, 48, 49))) assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), Array[Byte](0xDE.toByte, 0xAD.toByte, 0xBE.toByte, 0xEF.toByte))) - assert(rows(0).getBoolean(7) == true) + assert(rows(0).getBoolean(7)) assert(rows(0).getString(8) == "172.16.0.42") assert(rows(0).getString(9) == "192.168.0.0/16") assert(rows(0).getSeq(10) == Seq(1, 2)) assert(rows(0).getSeq(11) == Seq("a", null, "b")) assert(rows(0).getSeq(12).toSeq == Seq(0.11f, 0.22f)) + assert(rows(0).getString(13) == "d1") } test("Basic write test") { From 3a40c0e575fd4215302ea60c9821d31a5a138b8a Mon Sep 17 00:00:00 2001 From: Brandon Bradley Date: Thu, 28 Jan 2016 15:25:57 -0800 Subject: [PATCH 053/131] [SPARK-12749][SQL] add json option to parse floating-point types as DecimalType I tried to add this via `USE_BIG_DECIMAL_FOR_FLOATS` option from Jackson with no success. Added test for non-complex types. Should I add a test for complex types? Author: Brandon Bradley Closes #10936 from blbradley/spark-12749. --- python/pyspark/sql/readwriter.py | 2 ++ .../apache/spark/sql/DataFrameReader.scala | 2 ++ .../datasources/json/InferSchema.scala | 8 ++++-- .../datasources/json/JSONOptions.scala | 2 ++ .../datasources/json/JsonSuite.scala | 28 +++++++++++++++++++ 5 files changed, 40 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 0b20022b14b8..b1453c637f79 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -152,6 +152,8 @@ def json(self, path, schema=None): You can set the following JSON-specific options to deal with non-standard JSON files: * ``primitivesAsString`` (default ``false``): infers all primitive values as a string \ type + * `floatAsBigDecimal` (default `false`): infers all floating-point values as a decimal \ + type * ``allowComments`` (default ``false``): ignores Java/C++ style comment in JSON records * ``allowUnquotedFieldNames`` (default ``false``): allows unquoted JSON field names * ``allowSingleQuotes`` (default ``true``): allows single quotes in addition to double \ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 634c1bd4739b..2e0c6c7df967 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -252,6 +252,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * * You can set the following JSON-specific options to deal with non-standard JSON files: *
   * <li>`primitivesAsString` (default `false`): infers all primitive values as a string type</li>
+   * <li>`floatAsBigDecimal` (default `false`): infers all floating-point values as a decimal
+   * type</li>
   * <li>`allowComments` (default `false`): ignores Java/C++ style comment in JSON records</li>
   * <li>`allowUnquotedFieldNames` (default `false`): allows unquoted JSON field names</li>
  • `allowSingleQuotes` (default `true`): allows single quotes in addition to double quotes diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index 44d5e4ff7ec8..8b773ddfcb65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -134,8 +134,12 @@ private[json] object InferSchema { val v = parser.getDecimalValue DecimalType(v.precision(), v.scale()) case FLOAT | DOUBLE => - // TODO(davies): Should we use decimal if possible? - DoubleType + if (configOptions.floatAsBigDecimal) { + val v = parser.getDecimalValue + DecimalType(v.precision(), v.scale()) + } else { + DoubleType + } } case VALUE_TRUE | VALUE_FALSE => BooleanType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala index fe5b20697e40..31a95ed46121 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala @@ -34,6 +34,8 @@ private[sql] class JSONOptions( parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) val primitivesAsString = parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false) + val floatAsBigDecimal = + parameters.get("floatAsBigDecimal").map(_.toBoolean).getOrElse(false) val allowComments = parameters.get("allowComments").map(_.toBoolean).getOrElse(false) val allowUnquotedFieldNames = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 00eaeb0d34e8..dd83a0e36f6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -771,6 +771,34 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } + test("Loading a JSON dataset floatAsBigDecimal returns schema with float types as BigDecimal") { + val jsonDF = sqlContext.read.option("floatAsBigDecimal", "true").json(primitiveFieldAndType) + + val expectedSchema = StructType( + StructField("bigInteger", DecimalType(20, 0), true) :: + StructField("boolean", BooleanType, true) :: + StructField("double", DecimalType(17, -292), true) :: + StructField("integer", LongType, true) :: + StructField("long", LongType, true) :: + StructField("null", StringType, true) :: + StructField("string", StringType, true) :: Nil) + + assert(expectedSchema === jsonDF.schema) + + jsonDF.registerTempTable("jsonTable") + + checkAnswer( + sql("select * from jsonTable"), + Row(BigDecimal("92233720368547758070"), + true, + BigDecimal("1.7976931348623157E308"), + 10, + 21474836470L, + null, + "this is a simple string.") + ) + } + test("Loading a JSON dataset from a text file with SQL") { val dir = Utils.createTempDir() dir.delete() From 4637fc08a3733ec313218fb7e4d05064d9a6262d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 28 Jan 2016 16:25:21 -0800 Subject: [PATCH 054/131] [SPARK-11955][SQL] Mark optional fields in merging schema for safely pushdowning filters in Parquet JIRA: 
https://issues.apache.org/jira/browse/SPARK-11955 Currently we simply skip pushdowning filters in parquet if we enable schema merging. However, we can actually mark particular fields in merging schema for safely pushdowning filters in parquet. Author: Liang-Chi Hsieh Author: Liang-Chi Hsieh Closes #9940 from viirya/safe-pushdown-parquet-filters. --- .../org/apache/spark/sql/types/Metadata.scala | 5 +++ .../apache/spark/sql/types/StructType.scala | 34 ++++++++++++--- .../spark/sql/types/DataTypeSuite.scala | 14 ++++-- .../datasources/parquet/ParquetFilters.scala | 37 +++++++++++----- .../datasources/parquet/ParquetRelation.scala | 13 +++--- .../parquet/ParquetFilterSuite.scala | 43 ++++++++++++++++++- 6 files changed, 117 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala index 9e0f9943bc63..66f123682e11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -273,4 +273,9 @@ class MetadataBuilder { map.put(key, value) this } + + def remove(key: String): this.type = { + map.remove(key) + this + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 3bd733fa2d26..da0c92864e9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -334,6 +334,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru object StructType extends AbstractDataType { + private[sql] val metadataKeyForOptionalField = "_OPTIONAL_" + override private[sql] def defaultConcreteType: DataType = new StructType override private[sql] def acceptsType(other: DataType): Boolean = { @@ -359,6 +361,18 @@ object StructType extends AbstractDataType { protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) + def removeMetadata(key: String, dt: DataType): DataType = + dt match { + case StructType(fields) => + val newFields = fields.map { f => + val mb = new MetadataBuilder() + f.copy(dataType = removeMetadata(key, f.dataType), + metadata = mb.withMetadata(f.metadata).remove(key).build()) + } + StructType(newFields) + case _ => dt + } + private[sql] def merge(left: DataType, right: DataType): DataType = (left, right) match { case (ArrayType(leftElementType, leftContainsNull), @@ -376,24 +390,32 @@ object StructType extends AbstractDataType { case (StructType(leftFields), StructType(rightFields)) => val newFields = ArrayBuffer.empty[StructField] + // This metadata will record the fields that only exist in one of two StructTypes + val optionalMeta = new MetadataBuilder() val rightMapped = fieldsMap(rightFields) leftFields.foreach { case leftField @ StructField(leftName, leftType, leftNullable, _) => rightMapped.get(leftName) .map { case rightField @ StructField(_, rightType, rightNullable, _) => - leftField.copy( - dataType = merge(leftType, rightType), - nullable = leftNullable || rightNullable) - } - .orElse(Some(leftField)) + leftField.copy( + dataType = merge(leftType, rightType), + nullable = leftNullable || rightNullable) + } + .orElse { + optionalMeta.putBoolean(metadataKeyForOptionalField, true) + Some(leftField.copy(metadata = 
optionalMeta.build())) + } .foreach(newFields += _) } val leftMapped = fieldsMap(leftFields) rightFields .filterNot(f => leftMapped.get(f.name).nonEmpty) - .foreach(newFields += _) + .foreach { f => + optionalMeta.putBoolean(metadataKeyForOptionalField, true) + newFields += f.copy(metadata = optionalMeta.build()) + } StructType(newFields) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 706ecd29d135..c2bbca7c33f2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -122,7 +122,9 @@ class DataTypeSuite extends SparkFunSuite { val right = StructType(List()) val merged = left.merge(right) - assert(merged === left) + assert(DataType.equalsIgnoreCompatibleNullability(merged, left)) + assert(merged("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(merged("b").metadata.getBoolean(StructType.metadataKeyForOptionalField)) } test("merge where left is empty") { @@ -135,8 +137,9 @@ class DataTypeSuite extends SparkFunSuite { val merged = left.merge(right) - assert(right === merged) - + assert(DataType.equalsIgnoreCompatibleNullability(merged, right)) + assert(merged("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(merged("b").metadata.getBoolean(StructType.metadataKeyForOptionalField)) } test("merge where both are non-empty") { @@ -154,7 +157,10 @@ class DataTypeSuite extends SparkFunSuite { val merged = left.merge(right) - assert(merged === expected) + assert(DataType.equalsIgnoreCompatibleNullability(merged, expected)) + assert(merged("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(merged("b").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(merged("c").metadata.getBoolean(StructType.metadataKeyForOptionalField)) } test("merge where right contains type conflict") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index e9b734b0abf5..5a5cb5cf03d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -207,11 +207,26 @@ private[sql] object ParquetFilters { */ } + /** + * SPARK-11955: The optional fields will have metadata StructType.metadataKeyForOptionalField. + * These fields only exist in one side of merged schemas. Due to that, we can't push down filters + * using such fields, otherwise Parquet library will throw exception. Here we filter out such + * fields. + */ + private def getFieldMap(dataType: DataType): Array[(String, DataType)] = dataType match { + case StructType(fields) => + fields.filter { f => + !f.metadata.contains(StructType.metadataKeyForOptionalField) || + !f.metadata.getBoolean(StructType.metadataKeyForOptionalField) + }.map(f => f.name -> f.dataType) ++ fields.flatMap { f => getFieldMap(f.dataType) } + case _ => Array.empty[(String, DataType)] + } + /** * Converts data sources filters to Parquet filter predicates. 
*/ def createFilter(schema: StructType, predicate: sources.Filter): Option[FilterPredicate] = { - val dataTypeOf = schema.map(f => f.name -> f.dataType).toMap + val dataTypeOf = getFieldMap(schema).toMap relaxParquetValidTypeMap @@ -231,29 +246,29 @@ private[sql] object ParquetFilters { // Probably I missed something and obviously this should be changed. predicate match { - case sources.IsNull(name) => + case sources.IsNull(name) if dataTypeOf.contains(name) => makeEq.lift(dataTypeOf(name)).map(_(name, null)) - case sources.IsNotNull(name) => + case sources.IsNotNull(name) if dataTypeOf.contains(name) => makeNotEq.lift(dataTypeOf(name)).map(_(name, null)) - case sources.EqualTo(name, value) => + case sources.EqualTo(name, value) if dataTypeOf.contains(name) => makeEq.lift(dataTypeOf(name)).map(_(name, value)) - case sources.Not(sources.EqualTo(name, value)) => + case sources.Not(sources.EqualTo(name, value)) if dataTypeOf.contains(name) => makeNotEq.lift(dataTypeOf(name)).map(_(name, value)) - case sources.EqualNullSafe(name, value) => + case sources.EqualNullSafe(name, value) if dataTypeOf.contains(name) => makeEq.lift(dataTypeOf(name)).map(_(name, value)) - case sources.Not(sources.EqualNullSafe(name, value)) => + case sources.Not(sources.EqualNullSafe(name, value)) if dataTypeOf.contains(name) => makeNotEq.lift(dataTypeOf(name)).map(_(name, value)) - case sources.LessThan(name, value) => + case sources.LessThan(name, value) if dataTypeOf.contains(name) => makeLt.lift(dataTypeOf(name)).map(_(name, value)) - case sources.LessThanOrEqual(name, value) => + case sources.LessThanOrEqual(name, value) if dataTypeOf.contains(name) => makeLtEq.lift(dataTypeOf(name)).map(_(name, value)) - case sources.GreaterThan(name, value) => + case sources.GreaterThan(name, value) if dataTypeOf.contains(name) => makeGt.lift(dataTypeOf(name)).map(_(name, value)) - case sources.GreaterThanOrEqual(name, value) => + case sources.GreaterThanOrEqual(name, value) if dataTypeOf.contains(name) => makeGtEq.lift(dataTypeOf(name)).map(_(name, value)) case sources.In(name, valueSet) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index b460ec1d2604..f87590095d34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -258,7 +258,12 @@ private[sql] class ParquetRelation( job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]]) ParquetOutputFormat.setWriteSupportClass(job, classOf[CatalystWriteSupport]) - CatalystWriteSupport.setSchema(dataSchema, conf) + + // We want to clear this temporary metadata from saving into Parquet file. + // This metadata is only useful for detecting optional columns when pushdowning filters. + val dataSchemaToWrite = StructType.removeMetadata(StructType.metadataKeyForOptionalField, + dataSchema).asInstanceOf[StructType] + CatalystWriteSupport.setSchema(dataSchemaToWrite, conf) // Sets flags for `CatalystSchemaConverter` (which converts Catalyst schema to Parquet schema) // and `CatalystWriteSupport` (writing actual rows to Parquet files). 
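To make the round trip the surrounding hunks rely on concrete, here is a minimal sketch (not part of the patch; it assumes code sitting in the `org.apache.spark.sql` package, since the helpers involved are `private[sql]`): `StructType.merge` tags fields that exist in only one of the two schemas, and `StructType.removeMetadata` strips that tag again before the schema reaches the Parquet writer.

```
import org.apache.spark.sql.types._

val left  = StructType(Seq(StructField("a", IntegerType), StructField("b", StringType)))
val right = StructType(Seq(StructField("b", StringType), StructField("c", LongType)))

// "a" and "c" each exist in only one schema, so merge() marks them as optional.
val merged = left.merge(right)
merged("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)  // true
merged("c").metadata.getBoolean(StructType.metadataKeyForOptionalField)  // true

// Before writing, the temporary marker is stripped so it never ends up in Parquet files.
val cleaned = StructType
  .removeMetadata(StructType.metadataKeyForOptionalField, merged)
  .asInstanceOf[StructType]
cleaned.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField))  // true
```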
@@ -304,10 +309,6 @@ private[sql] class ParquetRelation( val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp - // When merging schemas is enabled and the column of the given filter does not exist, - // Parquet emits an exception which is an issue of Parquet (PARQUET-389). - val safeParquetFilterPushDown = !shouldMergeSchemas && parquetFilterPushDown - // Parquet row group size. We will use this value as the value for // mapreduce.input.fileinputformat.split.minsize and mapred.min.split.size if the value // of these flags are smaller than the parquet row group size. @@ -321,7 +322,7 @@ private[sql] class ParquetRelation( dataSchema, parquetBlockSize, useMetadataCache, - safeParquetFilterPushDown, + parquetFilterPushDown, assumeBinaryIsString, assumeInt96IsTimestamp) _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 97c5313f0fef..1796b3af0e37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelation} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -379,9 +380,47 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // If the "c = 1" filter gets pushed down, this query will throw an exception which // Parquet emits. This is a Parquet issue (PARQUET-389). + val df = sqlContext.read.parquet(pathOne, pathTwo).filter("c = 1").selectExpr("c", "b", "a") checkAnswer( - sqlContext.read.parquet(pathOne, pathTwo).filter("c = 1").selectExpr("c", "b", "a"), - (1 to 1).map(i => Row(i, i.toString, null))) + df, + Row(1, "1", null)) + + // The fields "a" and "c" only exist in one Parquet file. + assert(df.schema("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(df.schema("c").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + + val pathThree = s"${dir.getCanonicalPath}/table3" + df.write.parquet(pathThree) + + // We will remove the temporary metadata when writing Parquet file. + val schema = sqlContext.read.parquet(pathThree).schema + assert(schema.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField))) + + val pathFour = s"${dir.getCanonicalPath}/table4" + val dfStruct = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + dfStruct.select(struct("a").as("s")).write.parquet(pathFour) + + val pathFive = s"${dir.getCanonicalPath}/table5" + val dfStruct2 = sparkContext.parallelize(Seq((1, 1))).toDF("c", "b") + dfStruct2.select(struct("c").as("s")).write.parquet(pathFive) + + // If the "s.c = 1" filter gets pushed down, this query will throw an exception which + // Parquet emits. + val dfStruct3 = sqlContext.read.parquet(pathFour, pathFive).filter("s.c = 1") + .selectExpr("s") + checkAnswer(dfStruct3, Row(Row(null, 1))) + + // The fields "s.a" and "s.c" only exist in one Parquet file. 
+ val field = dfStruct3.schema("s").dataType.asInstanceOf[StructType] + assert(field("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(field("c").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + + val pathSix = s"${dir.getCanonicalPath}/table6" + dfStruct3.write.parquet(pathSix) + + // We will remove the temporary metadata when writing Parquet file. + val forPathSix = sqlContext.read.parquet(pathSix).schema + assert(forPathSix.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField))) } } } From b9dfdcc63bb12bc24de96060e756889c2ceda519 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 28 Jan 2016 17:01:12 -0800 Subject: [PATCH 055/131] Revert "[SPARK-13031] [SQL] cleanup codegen and improve test coverage" This reverts commit cc18a7199240bf3b03410c1ba6704fe7ce6ae38e. --- .../expressions/codegen/CodeGenerator.scala | 13 +- .../codegen/GenerateMutableProjection.scala | 2 +- .../sql/execution/WholeStageCodegen.scala | 188 ++++++------------ .../aggregate/TungstenAggregate.scala | 88 +++----- .../spark/sql/execution/basicOperators.scala | 96 ++++----- .../org/apache/spark/sql/SQLQuerySuite.scala | 103 +++++----- .../execution/metric/SQLMetricsSuite.scala | 34 ++-- .../apache/spark/sql/test/SQLTestUtils.scala | 2 +- .../sql/util/DataFrameCallbackSuite.scala | 10 +- 9 files changed, 202 insertions(+), 334 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index e6704cf8bb1f..2747c315ad37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -144,23 +144,14 @@ class CodegenContext { private val curId = new java.util.concurrent.atomic.AtomicInteger() - /** - * A prefix used to generate fresh name. - */ - var freshNamePrefix = "" - /** * Returns a term name that is unique within this instance of a `CodeGenerator`. * * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` * function.) 
*/ - def freshName(name: String): String = { - if (freshNamePrefix == "") { - s"$name${curId.getAndIncrement}" - } else { - s"${freshNamePrefix}_$name${curId.getAndIncrement}" - } + def freshName(prefix: String): String = { + s"$prefix${curId.getAndIncrement}" } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index ec31db19b94b..d9fe76133c6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -93,7 +93,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu // Can't call setNullAt on DecimalType, because we need to keep the offset s""" if (this.isNull_$i) { - ${ctx.setColumn("mutableRow", e.dataType, i, "null")}; + ${ctx.setColumn("mutableRow", e.dataType, i, null)}; } else { ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index ef81ba60f049..57f4945de980 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -22,11 +22,9 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, LeafExpression} import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.util.Utils /** * An interface for those physical operators that support codegen. @@ -44,16 +42,10 @@ trait CodegenSupport extends SparkPlan { private var parent: CodegenSupport = null /** - * Returns the RDD of InternalRow which generates the input rows. + * Returns an input RDD of InternalRow and Java source code to process them. */ - def upstream(): RDD[InternalRow] - - /** - * Returns Java source code to process the rows from upstream. - */ - def produce(ctx: CodegenContext, parent: CodegenSupport): String = { + def produce(ctx: CodegenContext, parent: CodegenSupport): (RDD[InternalRow], String) = { this.parent = parent - ctx.freshNamePrefix = nodeName doProduce(ctx) } @@ -74,41 +66,16 @@ trait CodegenSupport extends SparkPlan { * # call consume(), wich will call parent.doConsume() * } */ - protected def doProduce(ctx: CodegenContext): String + protected def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) /** - * Consume the columns generated from current SparkPlan, call it's parent. + * Consume the columns generated from current SparkPlan, call it's parent or create an iterator. 
*/ - def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { - if (input != null) { - assert(input.length == output.length) - } - parent.consumeChild(ctx, this, input, row) + protected def consume(ctx: CodegenContext, columns: Seq[ExprCode]): String = { + assert(columns.length == output.length) + parent.doConsume(ctx, this, columns) } - /** - * Consume the columns generated from it's child, call doConsume() or emit the rows. - */ - def consumeChild( - ctx: CodegenContext, - child: SparkPlan, - input: Seq[ExprCode], - row: String = null): String = { - ctx.freshNamePrefix = nodeName - if (row != null) { - ctx.currentVars = null - ctx.INPUT_ROW = row - val evals = child.output.zipWithIndex.map { case (attr, i) => - BoundReference(i, attr.dataType, attr.nullable).gen(ctx) - } - s""" - | ${evals.map(_.code).mkString("\n")} - | ${doConsume(ctx, evals)} - """.stripMargin - } else { - doConsume(ctx, input) - } - } /** * Generate the Java source code to process the rows from child SparkPlan. @@ -122,9 +89,7 @@ trait CodegenSupport extends SparkPlan { * # call consume(), which will call parent.doConsume() * } */ - protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { - throw new UnsupportedOperationException - } + def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String } @@ -137,36 +102,31 @@ trait CodegenSupport extends SparkPlan { case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { override def output: Seq[Attribute] = child.output - override def outputPartitioning: Partitioning = child.outputPartitioning - override def outputOrdering: Seq[SortOrder] = child.outputOrdering - - override def doPrepare(): Unit = { - child.prepare() - } - override def doExecute(): RDD[InternalRow] = { - child.execute() - } + override def supportCodegen: Boolean = true - override def supportCodegen: Boolean = false - - override def upstream(): RDD[InternalRow] = { - child.execute() - } - - override def doProduce(ctx: CodegenContext): String = { + override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true)) val row = ctx.freshName("row") ctx.INPUT_ROW = row ctx.currentVars = null val columns = exprs.map(_.gen(ctx)) - s""" - | while (input.hasNext()) { + val code = s""" + | while (input.hasNext()) { | InternalRow $row = (InternalRow) input.next(); | ${columns.map(_.code).mkString("\n")} | ${consume(ctx, columns)} | } """.stripMargin + (child.execute(), code) + } + + def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + throw new UnsupportedOperationException + } + + override def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException } override def simpleString: String = "INPUT" @@ -183,20 +143,16 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { * * -> execute() * | - * doExecute() ---------> upstream() -------> upstream() ------> execute() - * | - * -----------------> produce() + * doExecute() --------> produce() * | * doProduce() -------> produce() * | - * doProduce() + * doProduce() ---> execute() * | * consume() - * consumeChild() <-----------| + * doConsume() ------------| * | - * doConsume() - * | - * consumeChild() <----- consume() + * doConsume() <----- consume() * * SparkPlan A should override doProduce() and doConsume(). 
* @@ -206,48 +162,37 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) extends SparkPlan with CodegenSupport { - override def supportCodegen: Boolean = false - override def output: Seq[Attribute] = plan.output - override def outputPartitioning: Partitioning = plan.outputPartitioning - override def outputOrdering: Seq[SortOrder] = plan.outputOrdering - - override def doPrepare(): Unit = { - plan.prepare() - } override def doExecute(): RDD[InternalRow] = { val ctx = new CodegenContext - val code = plan.produce(ctx, this) + val (rdd, code) = plan.produce(ctx, this) val references = ctx.references.toArray val source = s""" public Object generate(Object[] references) { - return new GeneratedIterator(references); + return new GeneratedIterator(references); } class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { - private Object[] references; - ${ctx.declareMutableStates()} - ${ctx.declareAddedFunctions()} + private Object[] references; + ${ctx.declareMutableStates()} - public GeneratedIterator(Object[] references) { + public GeneratedIterator(Object[] references) { this.references = references; ${ctx.initMutableStates()} - } + } - protected void processNext() throws java.io.IOException { + protected void processNext() { $code - } + } } - """ - + """ // try to compile, helpful for debug // println(s"${CodeFormatter.format(source)}") CodeGenerator.compile(source) - plan.upstream().mapPartitions { iter => - + rdd.mapPartitions { iter => val clazz = CodeGenerator.compile(source) val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] buffer.setInput(iter) @@ -258,47 +203,29 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) } } - override def upstream(): RDD[InternalRow] = { + override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { throw new UnsupportedOperationException } - override def doProduce(ctx: CodegenContext): String = { - throw new UnsupportedOperationException - } - - override def consumeChild( - ctx: CodegenContext, - child: SparkPlan, - input: Seq[ExprCode], - row: String = null): String = { - - if (row != null) { - // There is an UnsafeRow already + override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + if (input.nonEmpty) { + val colExprs = output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable) + } + // generate the code to create a UnsafeRow + ctx.currentVars = input + val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) s""" - | currentRow = $row; + | ${code.code.trim} + | currentRow = ${code.value}; | return; - """.stripMargin + """.stripMargin } else { - assert(input != null) - if (input.nonEmpty) { - val colExprs = output.zipWithIndex.map { case (attr, i) => - BoundReference(i, attr.dataType, attr.nullable) - } - // generate the code to create a UnsafeRow - ctx.currentVars = input - val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) - s""" - | ${code.code.trim} - | currentRow = ${code.value}; - | return; - """.stripMargin - } else { - // There is no columns - s""" - | currentRow = unsafeRow; - | return; - """.stripMargin - } + // There is no columns + s""" + | currentRow = unsafeRow; + | return; + """.stripMargin } } @@ -319,7 +246,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) builder.append(simpleString) builder.append("\n") 
- plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder) + plan.generateTreeString(depth + 1, lastChildren :+children.isEmpty :+ true, builder) if (children.nonEmpty) { children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder)) children.last.generateTreeString(depth + 1, lastChildren :+ true, builder) @@ -359,14 +286,13 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru case plan: CodegenSupport if supportCodegen(plan) && // Whole stage codegen is only useful when there are at least two levels of operators that // support it (save at least one projection/iterator). - (Utils.isTesting || plan.children.exists(supportCodegen)) => + plan.children.exists(supportCodegen) => var inputs = ArrayBuffer[SparkPlan]() val combined = plan.transform { case p if !supportCodegen(p) => - val input = apply(p) // collapse them recursively - inputs += input - InputAdapter(input) + inputs += p + InputAdapter(p) }.asInstanceOf[CodegenSupport] WholeStageCodegen(combined, inputs) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index cbd2634b8900..23e54f344d25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -117,7 +117,9 @@ case class TungstenAggregate( override def supportCodegen: Boolean = { groupingExpressions.isEmpty && // ImperativeAggregate is not supported right now - !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) + !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) && + // final aggregation only have one row, do not need to codegen + !aggregateExpressions.exists(e => e.mode == Final || e.mode == Complete) } // The variables used as aggregation buffer @@ -125,11 +127,7 @@ case class TungstenAggregate( private val modes = aggregateExpressions.map(_.mode).distinct - override def upstream(): RDD[InternalRow] = { - child.asInstanceOf[CodegenSupport].upstream() - } - - protected override def doProduce(ctx: CodegenContext): String = { + protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { val initAgg = ctx.freshName("initAgg") ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") @@ -139,80 +137,50 @@ case class TungstenAggregate( bufVars = initExpr.map { e => val isNull = ctx.freshName("bufIsNull") val value = ctx.freshName("bufValue") - ctx.addMutableState("boolean", isNull, "") - ctx.addMutableState(ctx.javaType(e.dataType), value, "") // The initial expression should not access any column val ev = e.gen(ctx) val initVars = s""" - | $isNull = ${ev.isNull}; - | $value = ${ev.value}; + | boolean $isNull = ${ev.isNull}; + | ${ctx.javaType(e.dataType)} $value = ${ev.value}; """.stripMargin ExprCode(ev.code + initVars, isNull, value) } - // generate variables for output - val (resultVars, genResult) = if (modes.contains(Final) | modes.contains(Complete)) { - // evaluate aggregate results - ctx.currentVars = bufVars - val bufferAttrs = functions.flatMap(_.aggBufferAttributes) - val aggResults = functions.map(_.evaluateExpression).map { e => - BindReferences.bindReference(e, bufferAttrs).gen(ctx) - } - // evaluate result expressions - ctx.currentVars = aggResults - val resultVars = resultExpressions.map { e => - 
BindReferences.bindReference(e, aggregateAttributes).gen(ctx) - } - (resultVars, s""" - | ${aggResults.map(_.code).mkString("\n")} - | ${resultVars.map(_.code).mkString("\n")} - """.stripMargin) - } else { - // output the aggregate buffer directly - (bufVars, "") - } - - val doAgg = ctx.freshName("doAgg") - ctx.addNewFunction(doAgg, + val (rdd, childSource) = child.asInstanceOf[CodegenSupport].produce(ctx, this) + val source = s""" - | private void $doAgg() { + | if (!$initAgg) { + | $initAgg = true; + | | // initialize aggregation buffer | ${bufVars.map(_.code).mkString("\n")} | - | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} + | $childSource + | + | // output the result + | ${consume(ctx, bufVars)} | } - """.stripMargin) + """.stripMargin - s""" - | if (!$initAgg) { - | $initAgg = true; - | $doAgg(); - | - | // output the result - | $genResult - | - | ${consume(ctx, resultVars)} - | } - """.stripMargin + (rdd, source) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) - val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output - val updateExpr = aggregateExpressions.flatMap { e => - e.mode match { - case Partial | Complete => - e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions - case PartialMerge | Final => - e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions - } + // the mode could be only Partial or PartialMerge + val updateExpr = if (modes.contains(Partial)) { + functions.flatMap(_.updateExpressions) + } else { + functions.flatMap(_.mergeExpressions) } + val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output + val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr)) ctx.currentVars = bufVars ++ input // TODO: support subexpression elimination - val updates = updateExpr.zipWithIndex.map { case (e, i) => - val ev = BindReferences.bindReference[Expression](e, inputAttrs).gen(ctx) + val codes = boundExpr.zipWithIndex.map { case (e, i) => + val ev = e.gen(ctx) s""" | ${ev.code} | ${bufVars(i).isNull} = ${ev.isNull}; @@ -222,7 +190,7 @@ case class TungstenAggregate( s""" | // do aggregate and update aggregation buffer - | ${updates.mkString("")} + | ${codes.mkString("")} """.stripMargin } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index e7a73d5fbb4b..6deb72adad5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -37,15 +37,11 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) override def output: Seq[Attribute] = projectList.map(_.toAttribute) - override def upstream(): RDD[InternalRow] = { - child.asInstanceOf[CodegenSupport].upstream() - } - - protected override def doProduce(ctx: CodegenContext): String = { + protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { val exprs = projectList.map(x 
=> ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output))) ctx.currentVars = input @@ -80,15 +76,11 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - override def upstream(): RDD[InternalRow] = { - child.asInstanceOf[CodegenSupport].upstream() - } - - protected override def doProduce(ctx: CodegenContext): String = { + protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { val expr = ExpressionCanonicalizer.execute( BindReferences.bindReference(condition, child.output)) ctx.currentVars = input @@ -161,21 +153,17 @@ case class Range( output: Seq[Attribute]) extends LeafNode with CodegenSupport { - override def upstream(): RDD[InternalRow] = { - sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i)) - } - - protected override def doProduce(ctx: CodegenContext): String = { - val initTerm = ctx.freshName("initRange") + protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + val initTerm = ctx.freshName("range_initRange") ctx.addMutableState("boolean", initTerm, s"$initTerm = false;") - val partitionEnd = ctx.freshName("partitionEnd") + val partitionEnd = ctx.freshName("range_partitionEnd") ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;") - val number = ctx.freshName("number") + val number = ctx.freshName("range_number") ctx.addMutableState("long", number, s"$number = 0L;") - val overflow = ctx.freshName("overflow") + val overflow = ctx.freshName("range_overflow") ctx.addMutableState("boolean", overflow, s"$overflow = false;") - val value = ctx.freshName("value") + val value = ctx.freshName("range_value") val ev = ExprCode("", "false", value) val BigInt = classOf[java.math.BigInteger].getName val checkEnd = if (step > 0) { @@ -184,42 +172,38 @@ case class Range( s"$number > $partitionEnd" } - ctx.addNewFunction("initRange", - s""" - | private void initRange(int idx) { - | $BigInt index = $BigInt.valueOf(idx); - | $BigInt numSlice = $BigInt.valueOf(${numSlices}L); - | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L); - | $BigInt step = $BigInt.valueOf(${step}L); - | $BigInt start = $BigInt.valueOf(${start}L); - | - | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); - | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { - | $number = Long.MAX_VALUE; - | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { - | $number = Long.MIN_VALUE; - | } else { - | $number = st.longValue(); - | } - | - | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice) - | .multiply(step).add(start); - | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { - | $partitionEnd = Long.MAX_VALUE; - | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { - | $partitionEnd = Long.MIN_VALUE; - | } else { - | $partitionEnd = end.longValue(); - | } - | } - """.stripMargin) + val rdd = sqlContext.sparkContext.parallelize(0 until numSlices, numSlices) + .map(i => InternalRow(i)) - s""" + val code = s""" | // initialize Range | if (!$initTerm) { | $initTerm = true; 
| if (input.hasNext()) { - | initRange(((InternalRow) input.next()).getInt(0)); + | $BigInt index = $BigInt.valueOf(((InternalRow) input.next()).getInt(0)); + | $BigInt numSlice = $BigInt.valueOf(${numSlices}L); + | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L); + | $BigInt step = $BigInt.valueOf(${step}L); + | $BigInt start = $BigInt.valueOf(${start}L); + | + | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); + | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { + | $number = Long.MAX_VALUE; + | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { + | $number = Long.MIN_VALUE; + | } else { + | $number = st.longValue(); + | } + | + | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice) + | .multiply(step).add(start); + | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { + | $partitionEnd = Long.MAX_VALUE; + | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { + | $partitionEnd = Long.MIN_VALUE; + | } else { + | $partitionEnd = end.longValue(); + | } | } else { | return; | } @@ -234,6 +218,12 @@ case class Range( | ${consume(ctx, Seq(ev))} | } """.stripMargin + + (rdd, code) + } + + def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + throw new UnsupportedOperationException } protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 51a50c1fa30e..989cb2942918 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1939,61 +1939,58 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("Common subexpression elimination") { - // TODO: support subexpression elimination in whole stage codegen - withSQLConf("spark.sql.codegen.wholeStage" -> "false") { - // select from a table to prevent constant folding. - val df = sql("SELECT a, b from testData2 limit 1") - checkAnswer(df, Row(1, 1)) - - checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) - checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3)) - - // This does not work because the expressions get grouped like (a + a) + 1 - checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3)) - checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) - - // Identity udf that tracks the number of times it is called. - val countAcc = sparkContext.accumulator(0, "CallCount") - sqlContext.udf.register("testUdf", (x: Int) => { - countAcc.++=(1) - x - }) - - // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value - // is correct. - def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { - countAcc.setValue(0) - checkAnswer(df, expectedResult) - assert(countAcc.value == expectedCount) - } + // select from a table to prevent constant folding. + val df = sql("SELECT a, b from testData2 limit 1") + checkAnswer(df, Row(1, 1)) + + checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) + checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3)) + + // This does not work because the expressions get grouped like (a + a) + 1 + checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3)) + checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) + + // Identity udf that tracks the number of times it is called. 
+ val countAcc = sparkContext.accumulator(0, "CallCount") + sqlContext.udf.register("testUdf", (x: Int) => { + countAcc.++=(1) + x + }) - verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1) - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) - verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1) - verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2) - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1) - - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) - - val testUdf = functions.udf((x: Int) => { - countAcc.++=(1) - x - }) - verifyCallCount( - df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) - - // Would be nice if semantic equals for `+` understood commutative - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) - - // Try disabling it via configuration. - sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) - sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value + // is correct. + def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { + countAcc.setValue(0) + checkAnswer(df, expectedResult) + assert(countAcc.value == expectedCount) } + + verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1) + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2) + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1) + + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) + + val testUdf = functions.udf((x: Int) => { + countAcc.++=(1) + x + }) + verifyCallCount( + df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) + + // Would be nice if semantic equals for `+` understood commutative + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) + + // Try disabling it via configuration. 
+ sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) + sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) } test("SPARK-10707: nullability should be correctly propagated through set operations (1)") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 82f6811503c2..cbae19ebd269 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -335,24 +335,22 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { test("save metrics") { withTempPath { file => - withSQLConf("spark.sql.codegen.wholeStage" -> "false") { - val previousExecutionIds = sqlContext.listener.executionIdToData.keySet - // Assume the execution plan is - // PhysicalRDD(nodeId = 0) - person.select('name).write.format("json").save(file.getAbsolutePath) - sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) - assert(executionIds.size === 1) - val executionId = executionIds.head - val jobs = sqlContext.listener.getExecution(executionId).get.jobs - // Use "<=" because there is a race condition that we may miss some jobs - // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. - assert(jobs.size <= 1) - val metricValues = sqlContext.listener.getExecutionMetrics(executionId) - // Because "save" will create a new DataFrame internally, we cannot get the real metric id. - // However, we still can check the value. - assert(metricValues.values.toSeq === Seq("2")) - } + val previousExecutionIds = sqlContext.listener.executionIdToData.keySet + // Assume the execution plan is + // PhysicalRDD(nodeId = 0) + person.select('name).write.format("json").save(file.getAbsolutePath) + sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size === 1) + val executionId = executionIds.head + val jobs = sqlContext.listener.getExecution(executionId).get.jobs + // Use "<=" because there is a race condition that we may miss some jobs + // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. + assert(jobs.size <= 1) + val metricValues = sqlContext.listener.getExecutionMetrics(executionId) + // Because "save" will create a new DataFrame internally, we cannot get the real metric id. + // However, we still can check the value. 
+ assert(metricValues.values.toSeq === Seq("2")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 7d6bff8295d2..d48143762cac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -199,7 +199,7 @@ private[sql] trait SQLTestUtils val schema = df.schema val childRDD = df .queryExecution - .sparkPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] + .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] .child .execute() .map(row => Row.fromSeq(row.copy().toSeq(schema))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index a3e5243b68ab..9a24a2487a25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -97,12 +97,10 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { } sqlContext.listenerManager.register(listener) - withSQLConf("spark.sql.codegen.wholeStage" -> "false") { - val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count() - df.collect() - df.collect() - Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect() - } + val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count() + df.collect() + df.collect() + Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect() assert(metrics.length == 3) assert(metrics(0) == 1) From 66449b8dcdbc3dca126c34b42c4d0419c7648696 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 28 Jan 2016 22:20:52 -0800 Subject: [PATCH 056/131] [SPARK-12968][SQL] Implement command to set current database JIRA: https://issues.apache.org/jira/browse/SPARK-12968 Implement command to set current database. Author: Liang-Chi Hsieh Author: Liang-Chi Hsieh Closes #10916 from viirya/ddl-use-database. --- .../spark/sql/catalyst/analysis/Catalog.scala | 4 ++++ .../org/apache/spark/sql/execution/SparkQl.scala | 3 +++ .../apache/spark/sql/execution/commands.scala | 10 ++++++++++ .../spark/sql/hive/thriftserver/CliSuite.scala | 2 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 4 ++++ .../scala/org/apache/spark/sql/hive/HiveQl.scala | 2 -- .../spark/sql/hive/client/ClientInterface.scala | 3 +++ .../spark/sql/hive/client/ClientWrapper.scala | 9 +++++++++ .../sql/hive/execution/HiveQuerySuite.scala | 16 ++++++++++++++++ 9 files changed, 50 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index a8f89ce6de45..f2f9ec59417e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -46,6 +46,10 @@ trait Catalog { def lookupRelation(tableIdent: TableIdentifier, alias: Option[String] = None): LogicalPlan + def setCurrentDatabase(databaseName: String): Unit = { + throw new UnsupportedOperationException + } + /** * Returns tuples of (tableName, isTemporary) for all tables in the given database. * isTemporary is a Boolean value indicates if a table is a temporary or not. 
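For reference, a minimal PySpark sketch of what this command looks like from the user side, assuming a Hive-enabled build and an illustrative database name; a `USE <db>` statement is parsed by `SparkQl` into the new `SetDatabaseCommand`, which calls `catalog.setCurrentDatabase`:

```python
# Rough sketch only: database name and app name are illustrative.
# "USE sandbox_db" is routed through SparkQl's TOK_SWITCHDATABASE case
# to SetDatabaseCommand, which updates the catalog's current database.
from pyspark import SparkContext
from pyspark.sql import HiveContext

sc = SparkContext("local[2]", "use-database-sketch")
sqlContext = HiveContext(sc)

sqlContext.sql("CREATE DATABASE IF NOT EXISTS sandbox_db")
sqlContext.sql("USE sandbox_db")

# current_database() should now report the database set above
print(sqlContext.sql("SELECT current_database()").first()[0])  # expected: sandbox_db

sc.stop()
```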
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala index f6055306b6c9..a5bd8ee42dec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala @@ -55,6 +55,9 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs) ExplainCommand(nodeToPlan(query), extended = extended.isDefined) + case Token("TOK_SWITCHDATABASE", Token(database, Nil) :: Nil) => + SetDatabaseCommand(cleanIdentifier(database)) + case Token("TOK_DESCTABLE", describeArgs) => // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL val Some(tableType) :: formatted :: extended :: pretty :: Nil = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 3cfa3dfd9c7e..703e4643cbd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -408,3 +408,13 @@ case class DescribeFunction( } } } + +case class SetDatabaseCommand(databaseName: String) extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + sqlContext.catalog.setCurrentDatabase(databaseName) + Seq.empty[Row] + } + + override val output: Seq[Attribute] = Seq.empty +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index ab31d45a79a2..72da266da4d0 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -183,7 +183,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { "CREATE DATABASE hive_test_db;" -> "OK", "USE hive_test_db;" - -> "OK", + -> "", "CREATE TABLE hive_test(key INT, val STRING);" -> "OK", "SHOW TABLES;" diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index a9c0e9ab7cae..848aa4ec6fe5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -711,6 +711,10 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } override def unregisterAllTables(): Unit = {} + + override def setCurrentDatabase(databaseName: String): Unit = { + client.setCurrentDatabase(databaseName) + } } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 22841ed2116d..752c037a842a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -155,8 +155,6 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging "TOK_SHOWLOCKS", "TOK_SHOWPARTITIONS", - "TOK_SWITCHDATABASE", - "TOK_UNLOCKTABLE" ) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala 
index 9d9a55edd731..4eec3fef7408 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala @@ -109,6 +109,9 @@ private[hive] trait ClientInterface { /** Returns the name of the active database. */ def currentDatabase: String + /** Sets the name of current database. */ + def setCurrentDatabase(databaseName: String): Unit + /** Returns the metadata for specified database, throwing an exception if it doesn't exist */ def getDatabase(name: String): HiveDatabase = { getDatabaseOption(name).getOrElse(throw new NoSuchDatabaseException) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index ce7a305d437a..5307e924e7e5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -35,6 +35,7 @@ import org.apache.hadoop.hive.shims.{HadoopShims, ShimLoader} import org.apache.hadoop.security.UserGroupInformation import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.util.{CircularBuffer, Utils} @@ -229,6 +230,14 @@ private[hive] class ClientWrapper( state.getCurrentDatabase } + override def setCurrentDatabase(databaseName: String): Unit = withHiveState { + if (getDatabaseOption(databaseName).isDefined) { + state.setCurrentDatabase(databaseName) + } else { + throw new NoSuchDatabaseException + } + } + override def createDatabase(database: HiveDatabase): Unit = withHiveState { client.createDatabase( new Database( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 4659d745fe78..9632d27a2ffc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -28,6 +28,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkException, SparkFiles} import org.apache.spark.sql.{AnalysisException, DataFrame, Row} +import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin @@ -1262,6 +1263,21 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } + test("use database") { + val currentDatabase = sql("select current_database()").first().getString(0) + + sql("CREATE DATABASE hive_test_db") + sql("USE hive_test_db") + assert("hive_test_db" == sql("select current_database()").first().getString(0)) + + intercept[NoSuchDatabaseException] { + sql("USE not_existing_db") + } + + sql(s"USE $currentDatabase") + assert(currentDatabase == sql("select current_database()").first().getString(0)) + } + test("lookup hive UDF in another thread") { val e = intercept[AnalysisException] { range(1).selectExpr("not_a_udf()") From 721ced28b522cc00b45ca7fa32a99e80ad3de2f7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 28 Jan 2016 22:42:43 -0800 Subject: [PATCH 057/131] [SPARK-13067] [SQL] workaround 
for a weird scala reflection problem A simple workaround to avoid getting parameter types when convert a logical plan to json. Author: Wenchen Fan Closes #10970 from cloud-fan/reflection. --- .../spark/sql/catalyst/ScalaReflection.scala | 25 ++++++++++++++++--- .../spark/sql/catalyst/trees/TreeNode.scala | 4 +-- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 643228d0eb27..e5811efb436a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -601,6 +601,20 @@ object ScalaReflection extends ScalaReflection { getConstructorParameters(t) } + /** + * Returns the parameter names for the primary constructor of this class. + * + * Logically we should call `getConstructorParameters` and throw away the parameter types to get + * parameter names, however there are some weird scala reflection problems and this method is a + * workaround to avoid getting parameter types. + */ + def getConstructorParameterNames(cls: Class[_]): Seq[String] = { + val m = runtimeMirror(cls.getClassLoader) + val classSymbol = m.staticClass(cls.getName) + val t = classSymbol.selfType + constructParams(t).map(_.name.toString) + } + def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) } @@ -745,6 +759,12 @@ trait ScalaReflection { def getConstructorParameters(tpe: Type): Seq[(String, Type)] = { val formalTypeArgs = tpe.typeSymbol.asClass.typeParams val TypeRef(_, _, actualTypeArgs) = tpe + constructParams(tpe).map { p => + p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + } + } + + protected def constructParams(tpe: Type): Seq[Symbol] = { val constructorSymbol = tpe.member(nme.CONSTRUCTOR) val params = if (constructorSymbol.isMethod) { constructorSymbol.asMethod.paramss @@ -758,9 +778,6 @@ trait ScalaReflection { primaryConstructorSymbol.get.asMethod.paramss } } - - params.flatten.map { p => - p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) - } + params.flatten } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 57e1a3c9eb22..2df0683f9fa1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -512,7 +512,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } protected def jsonFields: List[JField] = { - val fieldNames = getConstructorParameters(getClass).map(_._1) + val fieldNames = getConstructorParameterNames(getClass) val fieldValues = productIterator.toSeq ++ otherCopyArgs assert(fieldNames.length == fieldValues.length, s"${getClass.getSimpleName} fields: " + fieldNames.mkString(", ") + s", values: " + fieldValues.map(_.toString).mkString(", ")) @@ -560,7 +560,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case obj if obj.getClass.getName.endsWith("$") => "object" -> obj.getClass.getName // returns null if the product type doesn't have a primary constructor, e.g. 
HiveFunctionWrapper case p: Product => try { - val fieldNames = getConstructorParameters(p.getClass).map(_._1) + val fieldNames = getConstructorParameterNames(p.getClass) val fieldValues = p.productIterator.toSeq assert(fieldNames.length == fieldValues.length) ("product-class" -> JString(p.getClass.getName)) :: fieldNames.zip(fieldValues).map { From 8d3cc3de7d116190911e7943ef3233fe3b7db1bf Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Thu, 28 Jan 2016 23:34:50 -0800 Subject: [PATCH 058/131] [SPARK-13050][BUILD] Scalatest tags fail build with the addition of the sketch module A dependency on the spark test tags was left out of the sketch module pom file causing builds to fail when test tags were used. This dependency is found in the pom file for every other module in spark. Author: Alex Bozarth Closes #10954 from ajbozarth/spark13050. --- common/sketch/pom.xml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index 67723fa421ab..2cafe8c548f5 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -35,6 +35,13 @@ sketch + + + org.apache.spark + spark-test-tags_${scala.binary.version} + + + target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes From 55561e7693dd2a5bf3c7f8026c725421801fd0ec Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 29 Jan 2016 01:59:59 -0800 Subject: [PATCH 059/131] [SPARK-13031][SQL] cleanup codegen and improve test coverage 1. enable whole stage codegen during tests even there is only one operator supports that. 2. split doProduce() into two APIs: upstream() and doProduce() 3. generate prefix for fresh names of each operator 4. pass UnsafeRow to parent directly (avoid getters and create UnsafeRow again) 5. fix bugs and tests. This PR re-open #10944 and fix the bug. Author: Davies Liu Closes #10977 from davies/gen_refactor. --- .../expressions/codegen/CodeGenerator.scala | 13 +- .../codegen/GenerateMutableProjection.scala | 2 +- .../sql/execution/WholeStageCodegen.scala | 188 ++++++++++++------ .../aggregate/AggregationIterator.scala | 2 +- .../aggregate/TungstenAggregate.scala | 98 ++++++--- .../spark/sql/execution/basicOperators.scala | 96 +++++---- .../spark/sql/DataFrameAggregateSuite.scala | 7 + .../org/apache/spark/sql/SQLQuerySuite.scala | 103 +++++----- .../execution/metric/SQLMetricsSuite.scala | 34 ++-- .../apache/spark/sql/test/SQLTestUtils.scala | 2 +- .../sql/util/DataFrameCallbackSuite.scala | 10 +- 11 files changed, 350 insertions(+), 205 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2747c315ad37..e6704cf8bb1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -144,14 +144,23 @@ class CodegenContext { private val curId = new java.util.concurrent.atomic.AtomicInteger() + /** + * A prefix used to generate fresh name. + */ + var freshNamePrefix = "" + /** * Returns a term name that is unique within this instance of a `CodeGenerator`. * * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` * function.) 
*/ - def freshName(prefix: String): String = { - s"$prefix${curId.getAndIncrement}" + def freshName(name: String): String = { + if (freshNamePrefix == "") { + s"$name${curId.getAndIncrement}" + } else { + s"${freshNamePrefix}_$name${curId.getAndIncrement}" + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index d9fe76133c6e..ec31db19b94b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -93,7 +93,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu // Can't call setNullAt on DecimalType, because we need to keep the offset s""" if (this.isNull_$i) { - ${ctx.setColumn("mutableRow", e.dataType, i, null)}; + ${ctx.setColumn("mutableRow", e.dataType, i, "null")}; } else { ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 57f4945de980..ef81ba60f049 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -22,9 +22,11 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, LeafExpression} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.util.Utils /** * An interface for those physical operators that support codegen. @@ -42,10 +44,16 @@ trait CodegenSupport extends SparkPlan { private var parent: CodegenSupport = null /** - * Returns an input RDD of InternalRow and Java source code to process them. + * Returns the RDD of InternalRow which generates the input rows. */ - def produce(ctx: CodegenContext, parent: CodegenSupport): (RDD[InternalRow], String) = { + def upstream(): RDD[InternalRow] + + /** + * Returns Java source code to process the rows from upstream. + */ + def produce(ctx: CodegenContext, parent: CodegenSupport): String = { this.parent = parent + ctx.freshNamePrefix = nodeName doProduce(ctx) } @@ -66,16 +74,41 @@ trait CodegenSupport extends SparkPlan { * # call consume(), wich will call parent.doConsume() * } */ - protected def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) + protected def doProduce(ctx: CodegenContext): String /** - * Consume the columns generated from current SparkPlan, call it's parent or create an iterator. + * Consume the columns generated from current SparkPlan, call it's parent. 
*/ - protected def consume(ctx: CodegenContext, columns: Seq[ExprCode]): String = { - assert(columns.length == output.length) - parent.doConsume(ctx, this, columns) + def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { + if (input != null) { + assert(input.length == output.length) + } + parent.consumeChild(ctx, this, input, row) } + /** + * Consume the columns generated from it's child, call doConsume() or emit the rows. + */ + def consumeChild( + ctx: CodegenContext, + child: SparkPlan, + input: Seq[ExprCode], + row: String = null): String = { + ctx.freshNamePrefix = nodeName + if (row != null) { + ctx.currentVars = null + ctx.INPUT_ROW = row + val evals = child.output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable).gen(ctx) + } + s""" + | ${evals.map(_.code).mkString("\n")} + | ${doConsume(ctx, evals)} + """.stripMargin + } else { + doConsume(ctx, input) + } + } /** * Generate the Java source code to process the rows from child SparkPlan. @@ -89,7 +122,9 @@ trait CodegenSupport extends SparkPlan { * # call consume(), which will call parent.doConsume() * } */ - def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String + protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + throw new UnsupportedOperationException + } } @@ -102,31 +137,36 @@ trait CodegenSupport extends SparkPlan { case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def doPrepare(): Unit = { + child.prepare() + } - override def supportCodegen: Boolean = true + override def doExecute(): RDD[InternalRow] = { + child.execute() + } - override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def supportCodegen: Boolean = false + + override def upstream(): RDD[InternalRow] = { + child.execute() + } + + override def doProduce(ctx: CodegenContext): String = { val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true)) val row = ctx.freshName("row") ctx.INPUT_ROW = row ctx.currentVars = null val columns = exprs.map(_.gen(ctx)) - val code = s""" - | while (input.hasNext()) { + s""" + | while (input.hasNext()) { | InternalRow $row = (InternalRow) input.next(); | ${columns.map(_.code).mkString("\n")} | ${consume(ctx, columns)} | } """.stripMargin - (child.execute(), code) - } - - def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { - throw new UnsupportedOperationException - } - - override def doExecute(): RDD[InternalRow] = { - throw new UnsupportedOperationException } override def simpleString: String = "INPUT" @@ -143,16 +183,20 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { * * -> execute() * | - * doExecute() --------> produce() + * doExecute() ---------> upstream() -------> upstream() ------> execute() + * | + * -----------------> produce() * | * doProduce() -------> produce() * | - * doProduce() ---> execute() + * doProduce() * | * consume() - * doConsume() ------------| + * consumeChild() <-----------| * | - * doConsume() <----- consume() + * doConsume() + * | + * consumeChild() <----- consume() * * SparkPlan A should override doProduce() and doConsume(). 
* @@ -162,37 +206,48 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) extends SparkPlan with CodegenSupport { + override def supportCodegen: Boolean = false + override def output: Seq[Attribute] = plan.output + override def outputPartitioning: Partitioning = plan.outputPartitioning + override def outputOrdering: Seq[SortOrder] = plan.outputOrdering + + override def doPrepare(): Unit = { + plan.prepare() + } override def doExecute(): RDD[InternalRow] = { val ctx = new CodegenContext - val (rdd, code) = plan.produce(ctx, this) + val code = plan.produce(ctx, this) val references = ctx.references.toArray val source = s""" public Object generate(Object[] references) { - return new GeneratedIterator(references); + return new GeneratedIterator(references); } class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { - private Object[] references; - ${ctx.declareMutableStates()} + private Object[] references; + ${ctx.declareMutableStates()} + ${ctx.declareAddedFunctions()} - public GeneratedIterator(Object[] references) { + public GeneratedIterator(Object[] references) { this.references = references; ${ctx.initMutableStates()} - } + } - protected void processNext() { + protected void processNext() throws java.io.IOException { $code - } + } } - """ + """ + // try to compile, helpful for debug // println(s"${CodeFormatter.format(source)}") CodeGenerator.compile(source) - rdd.mapPartitions { iter => + plan.upstream().mapPartitions { iter => + val clazz = CodeGenerator.compile(source) val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] buffer.setInput(iter) @@ -203,29 +258,47 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) } } - override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { throw new UnsupportedOperationException } - override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { - if (input.nonEmpty) { - val colExprs = output.zipWithIndex.map { case (attr, i) => - BoundReference(i, attr.dataType, attr.nullable) - } - // generate the code to create a UnsafeRow - ctx.currentVars = input - val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) - s""" - | ${code.code.trim} - | currentRow = ${code.value}; - | return; - """.stripMargin - } else { - // There is no columns + override def doProduce(ctx: CodegenContext): String = { + throw new UnsupportedOperationException + } + + override def consumeChild( + ctx: CodegenContext, + child: SparkPlan, + input: Seq[ExprCode], + row: String = null): String = { + + if (row != null) { + // There is an UnsafeRow already s""" - | currentRow = unsafeRow; + | currentRow = $row; | return; """.stripMargin + } else { + assert(input != null) + if (input.nonEmpty) { + val colExprs = output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable) + } + // generate the code to create a UnsafeRow + ctx.currentVars = input + val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) + s""" + | ${code.code.trim} + | currentRow = ${code.value}; + | return; + """.stripMargin + } else { + // There is no columns + s""" + | currentRow = unsafeRow; + | return; + """.stripMargin + } } } @@ -246,7 +319,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) builder.append(simpleString) builder.append("\n") - 
plan.generateTreeString(depth + 1, lastChildren :+children.isEmpty :+ true, builder) + plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder) if (children.nonEmpty) { children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder)) children.last.generateTreeString(depth + 1, lastChildren :+ true, builder) @@ -286,13 +359,14 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru case plan: CodegenSupport if supportCodegen(plan) && // Whole stage codegen is only useful when there are at least two levels of operators that // support it (save at least one projection/iterator). - plan.children.exists(supportCodegen) => + (Utils.isTesting || plan.children.exists(supportCodegen)) => var inputs = ArrayBuffer[SparkPlan]() val combined = plan.transform { case p if !supportCodegen(p) => - inputs += p - InputAdapter(p) + val input = apply(p) // collapse them recursively + inputs += input + InputAdapter(input) }.asInstanceOf[CodegenSupport] WholeStageCodegen(combined, inputs) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 0c74df0aa5fd..38da82c47ce1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -238,7 +238,7 @@ abstract class AggregationIterator( resultProjection(joinedRow(currentGroupingKey, currentBuffer)) } } else { - // Grouping-only: we only output values of grouping expressions. + // Grouping-only: we only output values based on grouping expressions. val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes) (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { resultProjection(currentGroupingKey) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 23e54f344d25..ff2f38bfd910 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -117,9 +117,7 @@ case class TungstenAggregate( override def supportCodegen: Boolean = { groupingExpressions.isEmpty && // ImperativeAggregate is not supported right now - !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) && - // final aggregation only have one row, do not need to codegen - !aggregateExpressions.exists(e => e.mode == Final || e.mode == Complete) + !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) } // The variables used as aggregation buffer @@ -127,7 +125,11 @@ case class TungstenAggregate( private val modes = aggregateExpressions.map(_.mode).distinct - protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { + child.asInstanceOf[CodegenSupport].upstream() + } + + protected override def doProduce(ctx: CodegenContext): String = { val initAgg = ctx.freshName("initAgg") ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") @@ -137,60 +139,96 @@ case class TungstenAggregate( bufVars = initExpr.map { e => val isNull = ctx.freshName("bufIsNull") val value = ctx.freshName("bufValue") + 
ctx.addMutableState("boolean", isNull, "") + ctx.addMutableState(ctx.javaType(e.dataType), value, "") // The initial expression should not access any column val ev = e.gen(ctx) val initVars = s""" - | boolean $isNull = ${ev.isNull}; - | ${ctx.javaType(e.dataType)} $value = ${ev.value}; + | $isNull = ${ev.isNull}; + | $value = ${ev.value}; """.stripMargin ExprCode(ev.code + initVars, isNull, value) } - val (rdd, childSource) = child.asInstanceOf[CodegenSupport].produce(ctx, this) - val source = + // generate variables for output + val bufferAttrs = functions.flatMap(_.aggBufferAttributes) + val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) { + // evaluate aggregate results + ctx.currentVars = bufVars + val aggResults = functions.map(_.evaluateExpression).map { e => + BindReferences.bindReference(e, bufferAttrs).gen(ctx) + } + // evaluate result expressions + ctx.currentVars = aggResults + val resultVars = resultExpressions.map { e => + BindReferences.bindReference(e, aggregateAttributes).gen(ctx) + } + (resultVars, s""" + | ${aggResults.map(_.code).mkString("\n")} + | ${resultVars.map(_.code).mkString("\n")} + """.stripMargin) + } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { + // output the aggregate buffer directly + (bufVars, "") + } else { + // no aggregate function, the result should be literals + val resultVars = resultExpressions.map(_.gen(ctx)) + (resultVars, resultVars.map(_.code).mkString("\n")) + } + + val doAgg = ctx.freshName("doAgg") + ctx.addNewFunction(doAgg, s""" - | if (!$initAgg) { - | $initAgg = true; - | + | private void $doAgg() { | // initialize aggregation buffer | ${bufVars.map(_.code).mkString("\n")} | - | $childSource - | - | // output the result - | ${consume(ctx, bufVars)} + | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} | } - """.stripMargin + """.stripMargin) - (rdd, source) + s""" + | if (!$initAgg) { + | $initAgg = true; + | $doAgg(); + | + | // output the result + | $genResult + | + | ${consume(ctx, resultVars)} + | } + """.stripMargin } - override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) - // the mode could be only Partial or PartialMerge - val updateExpr = if (modes.contains(Partial)) { - functions.flatMap(_.updateExpressions) - } else { - functions.flatMap(_.mergeExpressions) + val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output + val updateExpr = aggregateExpressions.flatMap { e => + e.mode match { + case Partial | Complete => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions + case PartialMerge | Final => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions + } } - val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output - val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr)) ctx.currentVars = bufVars ++ input // TODO: support subexpression elimination - val codes = boundExpr.zipWithIndex.map { case (e, i) => - val ev = e.gen(ctx) + val aggVals = updateExpr.map(BindReferences.bindReference(_, inputAttrs).gen(ctx)) + // aggregate buffer should be updated atomic + val updates = aggVals.zipWithIndex.map { case (ev, i) => s""" - | ${ev.code} | ${bufVars(i).isNull} = ${ev.isNull}; | ${bufVars(i).value} = ${ev.value}; 
""".stripMargin } s""" - | // do aggregate and update aggregation buffer - | ${codes.mkString("")} + | // do aggregate + | ${aggVals.map(_.code).mkString("\n")} + | // update aggregation buffer + | ${updates.mkString("")} """.stripMargin } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 6deb72adad5e..e7a73d5fbb4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -37,11 +37,15 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) override def output: Seq[Attribute] = projectList.map(_.toAttribute) - protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { + child.asInstanceOf[CodegenSupport].upstream() + } + + protected override def doProduce(ctx: CodegenContext): String = { child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { val exprs = projectList.map(x => ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output))) ctx.currentVars = input @@ -76,11 +80,15 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { + child.asInstanceOf[CodegenSupport].upstream() + } + + protected override def doProduce(ctx: CodegenContext): String = { child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { val expr = ExpressionCanonicalizer.execute( BindReferences.bindReference(condition, child.output)) ctx.currentVars = input @@ -153,17 +161,21 @@ case class Range( output: Seq[Attribute]) extends LeafNode with CodegenSupport { - protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { - val initTerm = ctx.freshName("range_initRange") + override def upstream(): RDD[InternalRow] = { + sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i)) + } + + protected override def doProduce(ctx: CodegenContext): String = { + val initTerm = ctx.freshName("initRange") ctx.addMutableState("boolean", initTerm, s"$initTerm = false;") - val partitionEnd = ctx.freshName("range_partitionEnd") + val partitionEnd = ctx.freshName("partitionEnd") ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;") - val number = ctx.freshName("range_number") + val number = ctx.freshName("number") ctx.addMutableState("long", number, s"$number = 0L;") - val overflow = ctx.freshName("range_overflow") + val overflow = ctx.freshName("overflow") ctx.addMutableState("boolean", overflow, s"$overflow = false;") - val value = ctx.freshName("range_value") + val value = ctx.freshName("value") val ev = ExprCode("", "false", value) val BigInt = classOf[java.math.BigInteger].getName val checkEnd = if (step > 0) { @@ -172,38 +184,42 @@ case class 
Range( s"$number > $partitionEnd" } - val rdd = sqlContext.sparkContext.parallelize(0 until numSlices, numSlices) - .map(i => InternalRow(i)) + ctx.addNewFunction("initRange", + s""" + | private void initRange(int idx) { + | $BigInt index = $BigInt.valueOf(idx); + | $BigInt numSlice = $BigInt.valueOf(${numSlices}L); + | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L); + | $BigInt step = $BigInt.valueOf(${step}L); + | $BigInt start = $BigInt.valueOf(${start}L); + | + | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); + | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { + | $number = Long.MAX_VALUE; + | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { + | $number = Long.MIN_VALUE; + | } else { + | $number = st.longValue(); + | } + | + | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice) + | .multiply(step).add(start); + | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { + | $partitionEnd = Long.MAX_VALUE; + | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { + | $partitionEnd = Long.MIN_VALUE; + | } else { + | $partitionEnd = end.longValue(); + | } + | } + """.stripMargin) - val code = s""" + s""" | // initialize Range | if (!$initTerm) { | $initTerm = true; | if (input.hasNext()) { - | $BigInt index = $BigInt.valueOf(((InternalRow) input.next()).getInt(0)); - | $BigInt numSlice = $BigInt.valueOf(${numSlices}L); - | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L); - | $BigInt step = $BigInt.valueOf(${step}L); - | $BigInt start = $BigInt.valueOf(${start}L); - | - | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); - | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { - | $number = Long.MAX_VALUE; - | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { - | $number = Long.MIN_VALUE; - | } else { - | $number = st.longValue(); - | } - | - | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice) - | .multiply(step).add(start); - | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { - | $partitionEnd = Long.MAX_VALUE; - | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { - | $partitionEnd = Long.MIN_VALUE; - | } else { - | $partitionEnd = end.longValue(); - | } + | initRange(((InternalRow) input.next()).getInt(0)); | } else { | return; | } @@ -218,12 +234,6 @@ case class Range( | ${consume(ctx, Seq(ev))} | } """.stripMargin - - (rdd, code) - } - - def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { - throw new UnsupportedOperationException } protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index b1004bc5bc29..08fb7c9d84c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -153,6 +153,13 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } + test("agg without groups and functions") { + checkAnswer( + testData2.agg(lit(1)), + Row(1) + ) + } + test("average") { checkAnswer( testData2.agg(avg('a), mean('a)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 989cb2942918..51a50c1fa30e 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1939,58 +1939,61 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("Common subexpression elimination") { - // select from a table to prevent constant folding. - val df = sql("SELECT a, b from testData2 limit 1") - checkAnswer(df, Row(1, 1)) - - checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) - checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3)) - - // This does not work because the expressions get grouped like (a + a) + 1 - checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3)) - checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) - - // Identity udf that tracks the number of times it is called. - val countAcc = sparkContext.accumulator(0, "CallCount") - sqlContext.udf.register("testUdf", (x: Int) => { - countAcc.++=(1) - x - }) + // TODO: support subexpression elimination in whole stage codegen + withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + // select from a table to prevent constant folding. + val df = sql("SELECT a, b from testData2 limit 1") + checkAnswer(df, Row(1, 1)) + + checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) + checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3)) + + // This does not work because the expressions get grouped like (a + a) + 1 + checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3)) + checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) + + // Identity udf that tracks the number of times it is called. + val countAcc = sparkContext.accumulator(0, "CallCount") + sqlContext.udf.register("testUdf", (x: Int) => { + countAcc.++=(1) + x + }) + + // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value + // is correct. + def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { + countAcc.setValue(0) + checkAnswer(df, expectedResult) + assert(countAcc.value == expectedCount) + } - // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value - // is correct. - def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { - countAcc.setValue(0) - checkAnswer(df, expectedResult) - assert(countAcc.value == expectedCount) + verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1) + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2) + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1) + + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) + + val testUdf = functions.udf((x: Int) => { + countAcc.++=(1) + x + }) + verifyCallCount( + df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) + + // Would be nice if semantic equals for `+` understood commutative + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) + + // Try disabling it via configuration. 
+ sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) + sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) } - - verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1) - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) - verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1) - verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2) - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1) - - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) - - val testUdf = functions.udf((x: Int) => { - countAcc.++=(1) - x - }) - verifyCallCount( - df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) - - // Would be nice if semantic equals for `+` understood commutative - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) - - // Try disabling it via configuration. - sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) - sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) } test("SPARK-10707: nullability should be correctly propagated through set operations (1)") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index cbae19ebd269..82f6811503c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -335,22 +335,24 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { test("save metrics") { withTempPath { file => - val previousExecutionIds = sqlContext.listener.executionIdToData.keySet - // Assume the execution plan is - // PhysicalRDD(nodeId = 0) - person.select('name).write.format("json").save(file.getAbsolutePath) - sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) - assert(executionIds.size === 1) - val executionId = executionIds.head - val jobs = sqlContext.listener.getExecution(executionId).get.jobs - // Use "<=" because there is a race condition that we may miss some jobs - // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. - assert(jobs.size <= 1) - val metricValues = sqlContext.listener.getExecutionMetrics(executionId) - // Because "save" will create a new DataFrame internally, we cannot get the real metric id. - // However, we still can check the value. 
- assert(metricValues.values.toSeq === Seq("2")) + withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + val previousExecutionIds = sqlContext.listener.executionIdToData.keySet + // Assume the execution plan is + // PhysicalRDD(nodeId = 0) + person.select('name).write.format("json").save(file.getAbsolutePath) + sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size === 1) + val executionId = executionIds.head + val jobs = sqlContext.listener.getExecution(executionId).get.jobs + // Use "<=" because there is a race condition that we may miss some jobs + // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. + assert(jobs.size <= 1) + val metricValues = sqlContext.listener.getExecutionMetrics(executionId) + // Because "save" will create a new DataFrame internally, we cannot get the real metric id. + // However, we still can check the value. + assert(metricValues.values.toSeq === Seq("2")) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index d48143762cac..7d6bff8295d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -199,7 +199,7 @@ private[sql] trait SQLTestUtils val schema = df.schema val childRDD = df .queryExecution - .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] + .sparkPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] .child .execute() .map(row => Row.fromSeq(row.copy().toSeq(schema))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index 9a24a2487a25..a3e5243b68ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -97,10 +97,12 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { } sqlContext.listenerManager.register(listener) - val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count() - df.collect() - df.collect() - Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect() + withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count() + df.collect() + df.collect() + Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect() + } assert(metrics.length == 3) assert(metrics(0) == 1) From e51b6eaa9e9c007e194d858195291b2b9fb27322 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 29 Jan 2016 09:22:24 -0800 Subject: [PATCH 060/131] [SPARK-13032][ML][PYSPARK] PySpark support model export/import and take LinearRegression as example * Implement ```MLWriter/MLWritable/MLReader/MLReadable``` for PySpark. * Making ```LinearRegression``` to support ```save/load``` as example. After this merged, the work for other transformers/estimators will be easy, then we can list and distribute the tasks to the community. cc mengxr jkbradley Author: Yanbo Liang Author: Joseph K. Bradley Closes #10469 from yanboliang/spark-11939. 
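For reference, a minimal sketch of the save/load round trip this adds, assuming a local PySpark session; the `save()`/`load()` calls mirror the doctest added to `pyspark/ml/regression.py`, while the SparkContext/SQLContext setup, toy DataFrame, and temp-path names below are illustrative only:

```python
# Rough usage sketch of estimator/model persistence in PySpark.
import shutil
import tempfile

from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.ml.regression import LinearRegression, LinearRegressionModel
from pyspark.mllib.linalg import Vectors

sc = SparkContext("local[2]", "lr-persistence-sketch")
sqlContext = SQLContext(sc)
df = sqlContext.createDataFrame(
    [(1.0, Vectors.dense(1.0)), (0.0, Vectors.sparse(1, [], []))],
    ["label", "features"])

lr = LinearRegression(maxIter=5, regParam=0.0)
model = lr.fit(df)

path = tempfile.mkdtemp()
lr.save(path + "/lr")              # persists the estimator's params
model.save(path + "/lr_model")     # persists the fitted coefficients/intercept

lr2 = LinearRegression.load(path + "/lr")
model2 = LinearRegressionModel.load(path + "/lr_model")
assert lr2.getMaxIter() == 5
assert model2.intercept == model.intercept

shutil.rmtree(path)
sc.stop()
```

These calls are backed by the new writer/reader wrappers added in `pyspark/ml/util.py`, which delegate to the JVM-side persistence implementations.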
--- python/pyspark/ml/param/__init__.py | 24 +++++ python/pyspark/ml/regression.py | 30 +++++- python/pyspark/ml/tests.py | 36 +++++-- python/pyspark/ml/util.py | 142 +++++++++++++++++++++++++++- python/pyspark/ml/wrapper.py | 33 ++++--- 5 files changed, 236 insertions(+), 29 deletions(-) diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 3da36d32c5af..ea86d6aeb8b3 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -314,3 +314,27 @@ def _copyValues(self, to, extra=None): if p in paramMap and to.hasParam(p.name): to._set(**{p.name: paramMap[p]}) return to + + def _resetUid(self, newUid): + """ + Changes the uid of this instance. This updates both + the stored uid and the parent uid of params and param maps. + This is used by persistence (loading). + :param newUid: new uid to use + :return: same instance, but with the uid and Param.parent values + updated, including within param maps + """ + self.uid = newUid + newDefaultParamMap = dict() + newParamMap = dict() + for param in self.params: + newParam = copy.copy(param) + newParam.parent = newUid + if param in self._defaultParamMap: + newDefaultParamMap[newParam] = self._defaultParamMap[param] + if param in self._paramMap: + newParamMap[newParam] = self._paramMap[param] + param.parent = newUid + self._defaultParamMap = newDefaultParamMap + self._paramMap = newParamMap + return self diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 74a2248ed07c..20dc6c2db91f 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -18,9 +18,9 @@ import warnings from pyspark import since -from pyspark.ml.util import keyword_only -from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * +from pyspark.ml.util import * +from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.mllib.common import inherit_doc @@ -35,7 +35,7 @@ @inherit_doc class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept, - HasStandardization, HasSolver, HasWeightCol): + HasStandardization, HasSolver, HasWeightCol, MLWritable, MLReadable): """ Linear regression. @@ -68,6 +68,25 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> lr_path = path + "/lr" + >>> lr.save(lr_path) + >>> lr2 = LinearRegression.load(lr_path) + >>> lr2.getMaxIter() + 5 + >>> model_path = path + "/lr_model" + >>> model.save(model_path) + >>> model2 = LinearRegressionModel.load(model_path) + >>> model.coefficients[0] == model2.coefficients[0] + True + >>> model.intercept == model2.intercept + True + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except OSError: + ... pass .. versionadded:: 1.4.0 """ @@ -106,7 +125,7 @@ def _create_model(self, java_model): return LinearRegressionModel(java_model) -class LinearRegressionModel(JavaModel): +class LinearRegressionModel(JavaModel, MLWritable, MLReadable): """ Model fitted by LinearRegression. 
@@ -821,9 +840,10 @@ def predict(self, features): if __name__ == "__main__": import doctest + import pyspark.ml.regression from pyspark.context import SparkContext from pyspark.sql import SQLContext - globs = globals().copy() + globs = pyspark.ml.regression.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext("local[2]", "ml.regression tests") diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index c45a159c460f..54806ee33666 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -34,18 +34,22 @@ else: import unittest -from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase -from pyspark.sql import DataFrame, SQLContext, Row -from pyspark.sql.functions import rand +from shutil import rmtree +import tempfile + +from pyspark.ml import Estimator, Model, Pipeline, Transformer from pyspark.ml.classification import LogisticRegression from pyspark.ml.evaluation import RegressionEvaluator +from pyspark.ml.feature import * from pyspark.ml.param import Param, Params from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed -from pyspark.ml.util import keyword_only -from pyspark.ml import Estimator, Model, Pipeline, Transformer -from pyspark.ml.feature import * +from pyspark.ml.regression import LinearRegression from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, CrossValidatorModel +from pyspark.ml.util import keyword_only from pyspark.mllib.linalg import DenseVector +from pyspark.sql import DataFrame, SQLContext, Row +from pyspark.sql.functions import rand +from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase class MockDataset(DataFrame): @@ -405,6 +409,26 @@ def test_fit_maximize_metric(self): self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") +class PersistenceTest(PySparkTestCase): + + def test_linear_regression(self): + lr = LinearRegression(maxIter=1) + path = tempfile.mkdtemp() + lr_path = path + "/lr" + lr.save(lr_path) + lr2 = LinearRegression.load(lr_path) + self.assertEqual(lr2.uid, lr2.maxIter.parent, + "Loaded LinearRegression instance uid (%s) did not match Param's uid (%s)" + % (lr2.uid, lr2.maxIter.parent)) + self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter], + "Loaded LinearRegression instance default params did not match " + + "original defaults") + try: + rmtree(path) + except OSError: + pass + + if __name__ == "__main__": from pyspark.ml.tests import * if xmlrunner: diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index cee9d67b0532..d7a813f56cd5 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -15,8 +15,27 @@ # limitations under the License. # -from functools import wraps +import sys import uuid +from functools import wraps + +if sys.version > '3': + basestring = str + +from pyspark import SparkContext, since +from pyspark.mllib.common import inherit_doc + + +def _jvm(): + """ + Returns the JVM view associated with SparkContext. Must be called + after SparkContext is initialized. + """ + jvm = SparkContext._jvm + if jvm: + return jvm + else: + raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?") def keyword_only(func): @@ -52,3 +71,124 @@ def _randomUID(cls): concatenates the class name, "_", and 12 random hex chars. """ return cls.__name__ + "_" + uuid.uuid4().hex[12:] + + +@inherit_doc +class JavaMLWriter(object): + """ + .. 
note:: Experimental + + Utility class that can save ML instances through their Scala implementation. + + .. versionadded:: 2.0.0 + """ + + def __init__(self, instance): + instance._transfer_params_to_java() + self._jwrite = instance._java_obj.write() + + def save(self, path): + """Save the ML instance to the input path.""" + if not isinstance(path, basestring): + raise TypeError("path should be a basestring, got type %s" % type(path)) + self._jwrite.save(path) + + def overwrite(self): + """Overwrites if the output path already exists.""" + self._jwrite.overwrite() + return self + + def context(self, sqlContext): + """Sets the SQL context to use for saving.""" + self._jwrite.context(sqlContext._ssql_ctx) + return self + + +@inherit_doc +class MLWritable(object): + """ + .. note:: Experimental + + Mixin for ML instances that provide JavaMLWriter. + + .. versionadded:: 2.0.0 + """ + + def write(self): + """Returns an JavaMLWriter instance for this ML instance.""" + return JavaMLWriter(self) + + def save(self, path): + """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" + self.write().save(path) + + +@inherit_doc +class JavaMLReader(object): + """ + .. note:: Experimental + + Utility class that can load ML instances through their Scala implementation. + + .. versionadded:: 2.0.0 + """ + + def __init__(self, clazz): + self._clazz = clazz + self._jread = self._load_java_obj(clazz).read() + + def load(self, path): + """Load the ML instance from the input path.""" + if not isinstance(path, basestring): + raise TypeError("path should be a basestring, got type %s" % type(path)) + java_obj = self._jread.load(path) + instance = self._clazz() + instance._java_obj = java_obj + instance._resetUid(java_obj.uid()) + instance._transfer_params_from_java() + return instance + + def context(self, sqlContext): + """Sets the SQL context to use for loading.""" + self._jread.context(sqlContext._ssql_ctx) + return self + + @classmethod + def _java_loader_class(cls, clazz): + """ + Returns the full class name of the Java ML instance. The default + implementation replaces "pyspark" by "org.apache.spark" in + the Python full class name. + """ + java_package = clazz.__module__.replace("pyspark", "org.apache.spark") + return ".".join([java_package, clazz.__name__]) + + @classmethod + def _load_java_obj(cls, clazz): + """Load the peer Java object of the ML instance.""" + java_class = cls._java_loader_class(clazz) + java_obj = _jvm() + for name in java_class.split("."): + java_obj = getattr(java_obj, name) + return java_obj + + +@inherit_doc +class MLReadable(object): + """ + .. note:: Experimental + + Mixin for instances that provide JavaMLReader. + + .. versionadded:: 2.0.0 + """ + + @classmethod + def read(cls): + """Returns an JavaMLReader instance for this class.""" + return JavaMLReader(cls) + + @classmethod + def load(cls, path): + """Reads an ML instance from the input path, a shortcut of `read().load(path)`.""" + return cls.read().load(path) diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index dd1d4b076edd..d4d48eb2150e 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -21,21 +21,10 @@ from pyspark.sql import DataFrame from pyspark.ml.param import Params from pyspark.ml.pipeline import Estimator, Transformer, Model +from pyspark.ml.util import _jvm from pyspark.mllib.common import inherit_doc, _java2py, _py2java -def _jvm(): - """ - Returns the JVM view associated with SparkContext. 
Must be called - after SparkContext is initialized. - """ - jvm = SparkContext._jvm - if jvm: - return jvm - else: - raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?") - - @inherit_doc class JavaWrapper(Params): """ @@ -159,15 +148,24 @@ class JavaModel(Model, JavaTransformer): __metaclass__ = ABCMeta - def __init__(self, java_model): + def __init__(self, java_model=None): """ Initialize this instance with a Java model object. Subclasses should call this constructor, initialize params, and then call _transformer_params_from_java. + + This instance can be instantiated without specifying java_model, + it will be assigned after that, but this scenario only used by + :py:class:`JavaMLReader` to load models. This is a bit of a + hack, but it is easiest since a proper fix would require + MLReader (in pyspark.ml.util) to depend on these wrappers, but + these wrappers depend on pyspark.ml.util (both directly and via + other ML classes). """ super(JavaModel, self).__init__() - self._java_obj = java_model - self.uid = java_model.uid() + if java_model is not None: + self._java_obj = java_model + self.uid = java_model.uid() def copy(self, extra=None): """ @@ -182,8 +180,9 @@ def copy(self, extra=None): if extra is None: extra = dict() that = super(JavaModel, self).copy(extra) - that._java_obj = self._java_obj.copy(self._empty_java_param_map()) - that._transfer_params_to_java() + if self._java_obj is not None: + that._java_obj = self._java_obj.copy(self._empty_java_param_map()) + that._transfer_params_to_java() return that def _call_java(self, name, *args): From e4c1162b6b3dbc8fc95cfe75c6e0bc2915575fb2 Mon Sep 17 00:00:00 2001 From: zhuol Date: Fri, 29 Jan 2016 11:54:58 -0600 Subject: [PATCH 061/131] [SPARK-10873] Support column sort and search for History Server. [SPARK-10873] Support column sort and search for History Server using jQuery DataTable and REST API. Before this commit, the history server was generated hard-coded html and can not support search, also, the sorting was disabled if there is any application that has more than one attempt. Supporting search and sort (over all applications rather than the 20 entries in the current page) in any case will greatly improve user experience. 1. Create the historypage-template.html for displaying application information in datables. 2. historypage.js uses jQuery to access the data from /api/v1/applications REST API, and use DataTable to display each application's information. For application that has more than one attempt, the RowsGroup is used to merge such entries while at the same time supporting sort and search. 3. "duration" and "lastUpdated" rest API are added to application's "attempts". 4. External javascirpt and css files for datatables, RowsGroup and jquery plugins are added with licenses clarified. Snapshots for how it looks like now: History page view: ![historypage](https://cloud.githubusercontent.com/assets/11683054/12184383/89bad774-b55a-11e5-84e4-b0276172976f.png) Search: ![search](https://cloud.githubusercontent.com/assets/11683054/12184385/8d3b94b0-b55a-11e5-869a-cc0ef0a4242a.png) Sort by started time: ![sort-by-started-time](https://cloud.githubusercontent.com/assets/11683054/12184387/8f757c3c-b55a-11e5-98c8-577936366566.png) Author: zhuol Closes #10648 from zhuoliu/10873. 
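The new page is populated entirely from the REST API rather than server-rendered HTML, and this patch adds `duration` and `lastUpdated` to each attempt returned by `/api/v1/applications`. As a rough, hedged illustration of the payload that `historypage.js` consumes, the Python sketch below (not part of the patch) fetches the endpoint and prints the new per-attempt fields; the history server address is assumed to be the default `localhost:18080`.

```python
# Hedged sketch (Python 3, stdlib only): inspect the JSON that the new
# history page renders client-side. The server address is an assumption.
import json
from urllib.request import urlopen

BASE = "http://localhost:18080/api/v1"  # assumed default history server port

resp = urlopen(BASE + "/applications")
apps = json.loads(resp.read().decode("utf-8"))

for app in apps:
    for attempt in app["attempts"]:
        # "duration" and "lastUpdated" are the attempt fields added here;
        # id/name and the other attempt fields were already exposed.
        print(app["id"], app["name"],
              attempt.get("duration"), attempt.get("lastUpdated"))
```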
--- .rat-excludes | 10 + LICENSE | 6 + .../spark/ui/static/dataTables.bootstrap.css | 319 ++++++++++ .../ui/static/dataTables.bootstrap.min.js | 8 + .../spark/ui/static/dataTables.rowsGroup.js | 224 +++++++ .../spark/ui/static/historypage-template.html | 81 +++ .../org/apache/spark/ui/static/historypage.js | 159 +++++ .../spark/ui/static/jquery.blockUI.min.js | 6 + .../ui/static/jquery.cookies.2.2.0.min.js | 18 + .../static/jquery.dataTables.1.10.4.min.css | 1 + .../ui/static/jquery.dataTables.1.10.4.min.js | 157 +++++ .../apache/spark/ui/static/jquery.mustache.js | 592 ++++++++++++++++++ .../spark/ui/static/jsonFormatter.min.css | 1 + .../spark/ui/static/jsonFormatter.min.js | 2 + .../spark/deploy/history/HistoryPage.scala | 193 +----- .../api/v1/ApplicationListResource.scala | 14 + .../org/apache/spark/status/api/v1/api.scala | 2 + .../scala/org/apache/spark/ui/SparkUI.scala | 2 + .../scala/org/apache/spark/ui/UIUtils.scala | 11 + .../application_list_json_expectation.json | 18 +- .../completed_app_list_json_expectation.json | 18 +- .../maxDate2_app_list_json_expectation.json | 4 +- .../maxDate_app_list_json_expectation.json | 6 +- .../minDate_app_list_json_expectation.json | 14 +- .../one_app_json_expectation.json | 4 +- ...ne_app_multi_attempt_json_expectation.json | 6 +- .../deploy/history/HistoryServerSuite.scala | 43 +- project/MimaExcludes.scala | 4 + 28 files changed, 1721 insertions(+), 202 deletions(-) create mode 100644 core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.css create mode 100644 core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.min.js create mode 100644 core/src/main/resources/org/apache/spark/ui/static/dataTables.rowsGroup.js create mode 100644 core/src/main/resources/org/apache/spark/ui/static/historypage-template.html create mode 100644 core/src/main/resources/org/apache/spark/ui/static/historypage.js create mode 100644 core/src/main/resources/org/apache/spark/ui/static/jquery.blockUI.min.js create mode 100644 core/src/main/resources/org/apache/spark/ui/static/jquery.cookies.2.2.0.min.js create mode 100644 core/src/main/resources/org/apache/spark/ui/static/jquery.dataTables.1.10.4.min.css create mode 100644 core/src/main/resources/org/apache/spark/ui/static/jquery.dataTables.1.10.4.min.js create mode 100644 core/src/main/resources/org/apache/spark/ui/static/jquery.mustache.js create mode 100755 core/src/main/resources/org/apache/spark/ui/static/jsonFormatter.min.css create mode 100755 core/src/main/resources/org/apache/spark/ui/static/jsonFormatter.min.js diff --git a/.rat-excludes b/.rat-excludes index a4f316a4aaa0..874a6ee9f404 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -25,6 +25,16 @@ graphlib-dot.min.js sorttable.js vis.min.js vis.min.css +dataTables.bootstrap.css +dataTables.bootstrap.min.js +dataTables.rowsGroup.js +jquery.blockUI.min.js +jquery.cookies.2.2.0.min.js +jquery.dataTables.1.10.4.min.css +jquery.dataTables.1.10.4.min.js +jquery.mustache.js +jsonFormatter.min.css +jsonFormatter.min.js .*avsc .*txt .*json diff --git a/LICENSE b/LICENSE index 9c944ac610af..9fc29db8d3f2 100644 --- a/LICENSE +++ b/LICENSE @@ -291,3 +291,9 @@ The text of each license is also included at licenses/LICENSE-[project].txt. 
(MIT License) dagre-d3 (https://github.com/cpettitt/dagre-d3) (MIT License) sorttable (https://github.com/stuartlangridge/sorttable) (MIT License) boto (https://github.com/boto/boto/blob/develop/LICENSE) + (MIT License) datatables (http://datatables.net/license) + (MIT License) mustache (https://github.com/mustache/mustache/blob/master/LICENSE) + (MIT License) cookies (http://code.google.com/p/cookies/wiki/License) + (MIT License) blockUI (http://jquery.malsup.com/block/) + (MIT License) RowsGroup (http://datatables.net/license/mit) + (MIT License) jsonFormatter (http://www.jqueryscript.net/other/jQuery-Plugin-For-Pretty-JSON-Formatting-jsonFormatter.html) diff --git a/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.css b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.css new file mode 100644 index 000000000000..faee0e50dbfe --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.css @@ -0,0 +1,319 @@ +div.dataTables_length label { + font-weight: normal; + text-align: left; + white-space: nowrap; +} + +div.dataTables_length select { + width: 75px; + display: inline-block; +} + +div.dataTables_filter { + text-align: right; +} + +div.dataTables_filter label { + font-weight: normal; + white-space: nowrap; + text-align: left; +} + +div.dataTables_filter input { + margin-left: 0.5em; + display: inline-block; +} + +div.dataTables_info { + padding-top: 8px; + white-space: nowrap; +} + +div.dataTables_paginate { + margin: 0; + white-space: nowrap; + text-align: right; +} + +div.dataTables_paginate ul.pagination { + margin: 2px 0; + white-space: nowrap; +} + +@media screen and (max-width: 767px) { + div.dataTables_length, + div.dataTables_filter, + div.dataTables_info, + div.dataTables_paginate { + text-align: center; + } +} + + +table.dataTable td, +table.dataTable th { + -webkit-box-sizing: content-box; + -moz-box-sizing: content-box; + box-sizing: content-box; +} + + +table.dataTable { + clear: both; + margin-top: 6px !important; + margin-bottom: 6px !important; + max-width: none !important; +} + +table.dataTable thead .sorting, +table.dataTable thead .sorting_asc, +table.dataTable thead .sorting_desc, +table.dataTable thead .sorting_asc_disabled, +table.dataTable thead .sorting_desc_disabled { + cursor: pointer; +} + +table.dataTable thead .sorting { background: url('../images/sort_both.png') no-repeat center right; } +table.dataTable thead .sorting_asc { background: url('../images/sort_asc.png') no-repeat center right; } +table.dataTable thead .sorting_desc { background: url('../images/sort_desc.png') no-repeat center right; } + +table.dataTable thead .sorting_asc_disabled { background: url('../images/sort_asc_disabled.png') no-repeat center right; } +table.dataTable thead .sorting_desc_disabled { background: url('../images/sort_desc_disabled.png') no-repeat center right; } + +table.dataTable thead > tr > th { + padding-left: 18px; + padding-right: 18px; +} + +table.dataTable th:active { + outline: none; +} + +/* Scrolling */ +div.dataTables_scrollHead table { + margin-bottom: 0 !important; + border-bottom-left-radius: 0; + border-bottom-right-radius: 0; +} + +div.dataTables_scrollHead table thead tr:last-child th:first-child, +div.dataTables_scrollHead table thead tr:last-child td:first-child { + border-bottom-left-radius: 0 !important; + border-bottom-right-radius: 0 !important; +} + +div.dataTables_scrollBody table { + border-top: none; + margin-top: 0 !important; + margin-bottom: 0 !important; +} + 
+div.dataTables_scrollBody tbody tr:first-child th, +div.dataTables_scrollBody tbody tr:first-child td { + border-top: none; +} + +div.dataTables_scrollFoot table { + margin-top: 0 !important; + border-top: none; +} + +/* Frustratingly the border-collapse:collapse used by Bootstrap makes the column + width calculations when using scrolling impossible to align columns. We have + to use separate + */ +table.table-bordered.dataTable { + border-collapse: separate !important; +} +table.table-bordered thead th, +table.table-bordered thead td { + border-left-width: 0; + border-top-width: 0; +} +table.table-bordered tbody th, +table.table-bordered tbody td { + border-left-width: 0; + border-bottom-width: 0; +} +table.table-bordered th:last-child, +table.table-bordered td:last-child { + border-right-width: 0; +} +div.dataTables_scrollHead table.table-bordered { + border-bottom-width: 0; +} + + + + +/* + * TableTools styles + */ +.table.dataTable tbody tr.active td, +.table.dataTable tbody tr.active th { + background-color: #08C; + color: white; +} + +.table.dataTable tbody tr.active:hover td, +.table.dataTable tbody tr.active:hover th { + background-color: #0075b0 !important; +} + +.table.dataTable tbody tr.active th > a, +.table.dataTable tbody tr.active td > a { + color: white; +} + +.table-striped.dataTable tbody tr.active:nth-child(odd) td, +.table-striped.dataTable tbody tr.active:nth-child(odd) th { + background-color: #017ebc; +} + +table.DTTT_selectable tbody tr { + cursor: pointer; +} + +div.DTTT .btn { + color: #333 !important; + font-size: 12px; +} + +div.DTTT .btn:hover { + text-decoration: none !important; +} + +ul.DTTT_dropdown.dropdown-menu { + z-index: 2003; +} + +ul.DTTT_dropdown.dropdown-menu a { + color: #333 !important; /* needed only when demo_page.css is included */ +} + +ul.DTTT_dropdown.dropdown-menu li { + position: relative; +} + +ul.DTTT_dropdown.dropdown-menu li:hover a { + background-color: #0088cc; + color: white !important; +} + +div.DTTT_collection_background { + z-index: 2002; +} + +/* TableTools information display */ +div.DTTT_print_info { + position: fixed; + top: 50%; + left: 50%; + width: 400px; + height: 150px; + margin-left: -200px; + margin-top: -75px; + text-align: center; + color: #333; + padding: 10px 30px; + opacity: 0.95; + + background-color: white; + border: 1px solid rgba(0, 0, 0, 0.2); + border-radius: 6px; + + -webkit-box-shadow: 0 3px 7px rgba(0, 0, 0, 0.5); + box-shadow: 0 3px 7px rgba(0, 0, 0, 0.5); +} + +div.DTTT_print_info h6 { + font-weight: normal; + font-size: 28px; + line-height: 28px; + margin: 1em; +} + +div.DTTT_print_info p { + font-size: 14px; + line-height: 20px; +} + +div.dataTables_processing { + position: absolute; + top: 50%; + left: 50%; + width: 100%; + height: 60px; + margin-left: -50%; + margin-top: -25px; + padding-top: 20px; + padding-bottom: 20px; + text-align: center; + font-size: 1.2em; + background-color: white; + background: -webkit-gradient(linear, left top, right top, color-stop(0%, rgba(255,255,255,0)), color-stop(25%, rgba(255,255,255,0.9)), color-stop(75%, rgba(255,255,255,0.9)), color-stop(100%, rgba(255,255,255,0))); + background: -webkit-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%); + background: -moz-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%); + background: -ms-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, 
rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%); + background: -o-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%); + background: linear-gradient(to right, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%); +} + + + +/* + * FixedColumns styles + */ +div.DTFC_LeftHeadWrapper table, +div.DTFC_LeftFootWrapper table, +div.DTFC_RightHeadWrapper table, +div.DTFC_RightFootWrapper table, +table.DTFC_Cloned tr.even { + background-color: white; + margin-bottom: 0; +} + +div.DTFC_RightHeadWrapper table , +div.DTFC_LeftHeadWrapper table { + border-bottom: none !important; + margin-bottom: 0 !important; + border-top-right-radius: 0 !important; + border-bottom-left-radius: 0 !important; + border-bottom-right-radius: 0 !important; +} + +div.DTFC_RightHeadWrapper table thead tr:last-child th:first-child, +div.DTFC_RightHeadWrapper table thead tr:last-child td:first-child, +div.DTFC_LeftHeadWrapper table thead tr:last-child th:first-child, +div.DTFC_LeftHeadWrapper table thead tr:last-child td:first-child { + border-bottom-left-radius: 0 !important; + border-bottom-right-radius: 0 !important; +} + +div.DTFC_RightBodyWrapper table, +div.DTFC_LeftBodyWrapper table { + border-top: none; + margin: 0 !important; +} + +div.DTFC_RightBodyWrapper tbody tr:first-child th, +div.DTFC_RightBodyWrapper tbody tr:first-child td, +div.DTFC_LeftBodyWrapper tbody tr:first-child th, +div.DTFC_LeftBodyWrapper tbody tr:first-child td { + border-top: none; +} + +div.DTFC_RightFootWrapper table, +div.DTFC_LeftFootWrapper table { + border-top: none; + margin-top: 0 !important; +} + + +/* + * FixedHeader styles + */ +div.FixedHeader_Cloned table { + margin: 0 !important +} + diff --git a/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.min.js b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.min.js new file mode 100644 index 000000000000..f0d09b9d5266 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.min.js @@ -0,0 +1,8 @@ +/*! + DataTables Bootstrap 3 integration + ©2011-2014 SpryMedia Ltd - datatables.net/license +*/ +(function(){var f=function(c,b){c.extend(!0,b.defaults,{dom:"<'row'<'col-sm-6'l><'col-sm-6'f>><'row'<'col-sm-12'tr>><'row'<'col-sm-6'i><'col-sm-6'p>>",renderer:"bootstrap"});c.extend(b.ext.classes,{sWrapper:"dataTables_wrapper form-inline dt-bootstrap",sFilterInput:"form-control input-sm",sLengthSelect:"form-control input-sm"});b.ext.renderer.pageButton.bootstrap=function(g,f,p,k,h,l){var q=new b.Api(g),r=g.oClasses,i=g.oLanguage.oPaginate,d,e,o=function(b,f){var j,m,n,a,k=function(a){a.preventDefault(); +c(a.currentTarget).hasClass("disabled")||q.page(a.data.action).draw(!1)};j=0;for(m=f.length;j",{"class":r.sPageButton+" "+ +e,"aria-controls":g.sTableId,tabindex:g.iTabIndex,id:0===p&&"string"===typeof a?g.sTableId+"_"+a:null}).append(c("",{href:"#"}).html(d)).appendTo(b),g.oApi._fnBindAction(n,{action:a},k))}};o(c(f).empty().html('
      ').children("ul"),k)};b.TableTools&&(c.extend(!0,b.TableTools.classes,{container:"DTTT btn-group",buttons:{normal:"btn btn-default",disabled:"disabled"},collection:{container:"DTTT_dropdown dropdown-menu",buttons:{normal:"",disabled:"disabled"}},print:{info:"DTTT_print_info"}, +select:{row:"active"}}),c.extend(!0,b.TableTools.DEFAULTS.oTags,{collection:{container:"ul",button:"li",liner:"a"}}))};"function"===typeof define&&define.amd?define(["jquery","datatables"],f):"object"===typeof exports?f(require("jquery"),require("datatables")):jQuery&&f(jQuery,jQuery.fn.dataTable)})(window,document); diff --git a/core/src/main/resources/org/apache/spark/ui/static/dataTables.rowsGroup.js b/core/src/main/resources/org/apache/spark/ui/static/dataTables.rowsGroup.js new file mode 100644 index 000000000000..983c3a564fb1 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/dataTables.rowsGroup.js @@ -0,0 +1,224 @@ +/*! RowsGroup for DataTables v1.0.0 + * 2015 Alexey Shildyakov ashl1future@gmail.com + */ + +/** + * @summary RowsGroup + * @description Group rows by specified columns + * @version 1.0.0 + * @file dataTables.rowsGroup.js + * @author Alexey Shildyakov (ashl1future@gmail.com) + * @contact ashl1future@gmail.com + * @copyright Alexey Shildyakov + * + * License MIT - http://datatables.net/license/mit + * + * This feature plug-in for DataTables automatically merges columns cells + * based on it's values equality. It supports multi-column row grouping + * in according to the requested order with dependency from each previous + * requested columns. Now it supports ordering and searching. + * Please see the example.html for details. + * + * Rows grouping in DataTables can be enabled by using any one of the following + * options: + * + * * Setting the `rowsGroup` parameter in the DataTables initialisation + * to array which contains columns selectors + * (https://datatables.net/reference/type/column-selector) used for grouping. i.e. + * rowsGroup = [1, 'columnName:name', ] + * * Setting the `rowsGroup` parameter in the DataTables defaults + * (thus causing all tables to have this feature) - i.e. + * `$.fn.dataTable.defaults.RowsGroup = [0]`. + * * Creating a new instance: `new $.fn.dataTable.RowsGroup( table, columnsForGrouping );` + * where `table` is a DataTable's API instance and `columnsForGrouping` is the array + * described above. 
+ * + * For more detailed information please see: + * + */ + +(function($){ + +ShowedDataSelectorModifier = { + order: 'current', + page: 'current', + search: 'applied', +} + +GroupedColumnsOrderDir = 'desc'; // change + + +/* + * columnsForGrouping: array of DTAPI:cell-selector for columns for which rows grouping is applied + */ +var RowsGroup = function ( dt, columnsForGrouping ) +{ + this.table = dt.table(); + this.columnsForGrouping = columnsForGrouping; + // set to True when new reorder is applied by RowsGroup to prevent order() looping + this.orderOverrideNow = false; + this.order = [] + + self = this; + $(document).on('order.dt', function ( e, settings) { + if (!self.orderOverrideNow) { + self._updateOrderAndDraw() + } + self.orderOverrideNow = false; + }) + + $(document).on('draw.dt', function ( e, settings) { + self._mergeCells() + }) + + this._updateOrderAndDraw(); +}; + + +RowsGroup.prototype = { + _getOrderWithGroupColumns: function (order, groupedColumnsOrderDir) + { + if (groupedColumnsOrderDir === undefined) + groupedColumnsOrderDir = GroupedColumnsOrderDir + + var self = this; + var groupedColumnsIndexes = this.columnsForGrouping.map(function(columnSelector){ + return self.table.column(columnSelector).index() + }) + var groupedColumnsKnownOrder = order.filter(function(columnOrder){ + return groupedColumnsIndexes.indexOf(columnOrder[0]) >= 0 + }) + var nongroupedColumnsOrder = order.filter(function(columnOrder){ + return groupedColumnsIndexes.indexOf(columnOrder[0]) < 0 + }) + var groupedColumnsKnownOrderIndexes = groupedColumnsKnownOrder.map(function(columnOrder){ + return columnOrder[0] + }) + var groupedColumnsOrder = groupedColumnsIndexes.map(function(iColumn){ + var iInOrderIndexes = groupedColumnsKnownOrderIndexes.indexOf(iColumn) + if (iInOrderIndexes >= 0) + return [iColumn, groupedColumnsKnownOrder[iInOrderIndexes][1]] + else + return [iColumn, groupedColumnsOrderDir] + }) + + groupedColumnsOrder.push.apply(groupedColumnsOrder, nongroupedColumnsOrder) + return groupedColumnsOrder; + }, + + // Workaround: the DT reset ordering to 'desc' from multi-ordering if user order on one column (without shift) + // but because we always has multi-ordering due to grouped rows this happens every time + _getInjectedMonoSelectWorkaround: function(order) + { + if (order.length === 1) { + // got mono order - workaround here + var orderingColumn = order[0][0] + var previousOrder = this.order.map(function(val){ + return val[0] + }) + var iColumn = previousOrder.indexOf(orderingColumn); + if (iColumn >= 0) { + // assume change the direction, because we already has that in previous order + return [[orderingColumn, this._toogleDirection(this.order[iColumn][1])]] + } // else This is the new ordering column. Proceed as is. 
+ } // else got multi order - work normal + return order; + }, + + _mergeCells: function() + { + var columnsIndexes = this.table.columns(this.columnsForGrouping, ShowedDataSelectorModifier).indexes().toArray() + var showedRowsCount = this.table.rows(ShowedDataSelectorModifier)[0].length + this._mergeColumn(0, showedRowsCount - 1, columnsIndexes) + }, + + // the index is relative to the showed data + // (selector-modifier = {order: 'current', page: 'current', search: 'applied'}) index + _mergeColumn: function(iStartRow, iFinishRow, columnsIndexes) + { + var columnsIndexesCopy = columnsIndexes.slice() + currentColumn = columnsIndexesCopy.shift() + currentColumn = this.table.column(currentColumn, ShowedDataSelectorModifier) + + var columnNodes = currentColumn.nodes() + var columnValues = currentColumn.data() + + var newSequenceRow = iStartRow, + iRow; + for (iRow = iStartRow + 1; iRow <= iFinishRow; ++iRow) { + + if (columnValues[iRow] === columnValues[newSequenceRow]) { + $(columnNodes[iRow]).hide() + } else { + $(columnNodes[newSequenceRow]).show() + $(columnNodes[newSequenceRow]).attr('rowspan', (iRow-1) - newSequenceRow + 1) + + if (columnsIndexesCopy.length > 0) + this._mergeColumn(newSequenceRow, (iRow-1), columnsIndexesCopy) + + newSequenceRow = iRow; + } + + } + $(columnNodes[newSequenceRow]).show() + $(columnNodes[newSequenceRow]).attr('rowspan', (iRow-1)- newSequenceRow + 1) + if (columnsIndexesCopy.length > 0) + this._mergeColumn(newSequenceRow, (iRow-1), columnsIndexesCopy) + }, + + _toogleDirection: function(dir) + { + return dir == 'asc'? 'desc': 'asc'; + }, + + _updateOrderAndDraw: function() + { + this.orderOverrideNow = true; + + var currentOrder = this.table.order(); + currentOrder = this._getInjectedMonoSelectWorkaround(currentOrder); + this.order = this._getOrderWithGroupColumns(currentOrder) + // this.table.order($.extend(true, Array(), this.order)) // disable this line in order to support sorting on non-grouped columns + this.table.draw(false) + }, +}; + + +$.fn.dataTable.RowsGroup = RowsGroup; +$.fn.DataTable.RowsGroup = RowsGroup; + +// Automatic initialisation listener +$(document).on( 'init.dt', function ( e, settings ) { + if ( e.namespace !== 'dt' ) { + return; + } + + var api = new $.fn.dataTable.Api( settings ); + + if ( settings.oInit.rowsGroup || + $.fn.dataTable.defaults.rowsGroup ) + { + options = settings.oInit.rowsGroup? + settings.oInit.rowsGroup: + $.fn.dataTable.defaults.rowsGroup; + new RowsGroup( api, options ); + } +} ); + +}(jQuery)); + +/* + +TODO: Provide function which determines the all s and s with "rowspan" html-attribute is parent (groupped) for the specified or . To use in selections, editing or hover styles. + +TODO: Feature +Use saved order direction for grouped columns + Split the columns into grouped and ungrouped. + + user = grouped+ungrouped + grouped = grouped + saved = grouped+ungrouped + + For grouped uses following order: user -> saved (because 'saved' include 'grouped' after first initialisation). 
This should be done with saving order like for 'groupedColumns' + For ungrouped: uses only 'user' input ordering +*/ \ No newline at end of file diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html new file mode 100644 index 000000000000..66d111e43909 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html @@ -0,0 +1,81 @@ + + + diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js new file mode 100644 index 000000000000..785abe45bc56 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// this function works exactly the same as UIUtils.formatDuration +function formatDuration(milliseconds) { + if (milliseconds < 100) { + return milliseconds + " ms"; + } + var seconds = milliseconds * 1.0 / 1000; + if (seconds < 1) { + return seconds.toFixed(1) + " s"; + } + if (seconds < 60) { + return seconds.toFixed(0) + " s"; + } + var minutes = seconds / 60; + if (minutes < 10) { + return minutes.toFixed(1) + " min"; + } else if (minutes < 60) { + return minutes.toFixed(0) + " min"; + } + var hours = minutes / 60; + return hours.toFixed(1) + " h"; +} + +function formatDate(date) { + return date.split(".")[0].replace("T", " "); +} + +function getParameterByName(name, searchString) { + var regex = new RegExp("[\\?&]" + name + "=([^&#]*)"), + results = regex.exec(searchString); + return results === null ? "" : decodeURIComponent(results[1].replace(/\+/g, " ")); +} + +jQuery.extend( jQuery.fn.dataTableExt.oSort, { + "title-numeric-pre": function ( a ) { + var x = a.match(/title="*(-?[0-9\.]+)/)[1]; + return parseFloat( x ); + }, + + "title-numeric-asc": function ( a, b ) { + return ((a < b) ? -1 : ((a > b) ? 1 : 0)); + }, + + "title-numeric-desc": function ( a, b ) { + return ((a < b) ? 1 : ((a > b) ? -1 : 0)); + } +} ); + +$(document).ajaxStop($.unblockUI); +$(document).ajaxStart(function(){ + $.blockUI({ message: '

<h3>Loading history summary...</h3>
      '}); +}); + +$(document).ready(function() { + $.extend( $.fn.dataTable.defaults, { + stateSave: true, + lengthMenu: [[20,40,60,100,-1], [20, 40, 60, 100, "All"]], + pageLength: 20 + }); + + historySummary = $("#history-summary"); + searchString = historySummary["context"]["location"]["search"]; + requestedIncomplete = getParameterByName("showIncomplete", searchString); + requestedIncomplete = (requestedIncomplete == "true" ? true : false); + + $.getJSON("/api/v1/applications", function(response,status,jqXHR) { + var array = []; + var hasMultipleAttempts = false; + for (i in response) { + var app = response[i]; + if (app["attempts"][0]["completed"] == requestedIncomplete) { + continue; // if we want to show for Incomplete, we skip the completed apps; otherwise skip incomplete ones. + } + var id = app["id"]; + var name = app["name"]; + if (app["attempts"].length > 1) { + hasMultipleAttempts = true; + } + var num = app["attempts"].length; + for (j in app["attempts"]) { + var attempt = app["attempts"][j]; + attempt["startTime"] = formatDate(attempt["startTime"]); + attempt["endTime"] = formatDate(attempt["endTime"]); + attempt["lastUpdated"] = formatDate(attempt["lastUpdated"]); + var app_clone = {"id" : id, "name" : name, "num" : num, "attempts" : [attempt]}; + array.push(app_clone); + } + } + + var data = {"applications": array} + $.get("/static/historypage-template.html", function(template) { + historySummary.append(Mustache.render($(template).filter("#history-summary-template").html(),data)); + var selector = "#history-summary-table"; + var conf = { + "columns": [ + {name: 'first'}, + {name: 'second'}, + {name: 'third'}, + {name: 'fourth'}, + {name: 'fifth'}, + {name: 'sixth', type: "title-numeric"}, + {name: 'seventh'}, + {name: 'eighth'}, + ], + }; + + var rowGroupConf = { + "rowsGroup": [ + 'first:name', + 'second:name' + ], + }; + + if (hasMultipleAttempts) { + jQuery.extend(conf, rowGroupConf); + var rowGroupCells = document.getElementsByClassName("rowGroupColumn"); + for (i = 0; i < rowGroupCells.length; i++) { + rowGroupCells[i].style='background-color: #ffffff'; + } + } + + if (!hasMultipleAttempts) { + var attemptIDCells = document.getElementsByClassName("attemptIDSpan"); + for (i = 0; i < attemptIDCells.length; i++) { + attemptIDCells[i].style.display='none'; + } + } + + var durationCells = document.getElementsByClassName("durationClass"); + for (i = 0; i < durationCells.length; i++) { + var timeInMilliseconds = parseInt(durationCells[i].title); + durationCells[i].innerHTML = formatDuration(timeInMilliseconds); + } + + if ($(selector.concat(" tr")).length < 20) { + $.extend(conf, {paging: false}); + } + + $(selector).DataTable(conf); + $('#hisotry-summary [data-toggle="tooltip"]').tooltip(); + }); + }); +}); diff --git a/core/src/main/resources/org/apache/spark/ui/static/jquery.blockUI.min.js b/core/src/main/resources/org/apache/spark/ui/static/jquery.blockUI.min.js new file mode 100644 index 000000000000..1e84b3ec21c4 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/jquery.blockUI.min.js @@ -0,0 +1,6 @@ +/* +* jQuery BlockUI; v20131009 +* http://jquery.malsup.com/block/ +* Copyright (c) 2013 M. 
Alsup; Dual licensed: MIT/GPL +*/ +(function(){"use strict";function e(e){function o(o,i){var s,h,k=o==window,v=i&&void 0!==i.message?i.message:void 0;if(i=e.extend({},e.blockUI.defaults,i||{}),!i.ignoreIfBlocked||!e(o).data("blockUI.isBlocked")){if(i.overlayCSS=e.extend({},e.blockUI.defaults.overlayCSS,i.overlayCSS||{}),s=e.extend({},e.blockUI.defaults.css,i.css||{}),i.onOverlayClick&&(i.overlayCSS.cursor="pointer"),h=e.extend({},e.blockUI.defaults.themedCSS,i.themedCSS||{}),v=void 0===v?i.message:v,k&&b&&t(window,{fadeOut:0}),v&&"string"!=typeof v&&(v.parentNode||v.jquery)){var y=v.jquery?v[0]:v,m={};e(o).data("blockUI.history",m),m.el=y,m.parent=y.parentNode,m.display=y.style.display,m.position=y.style.position,m.parent&&m.parent.removeChild(y)}e(o).data("blockUI.onUnblock",i.onUnblock);var g,I,w,U,x=i.baseZ;g=r||i.forceIframe?e(''):e(''),I=i.theme?e(''):e(''),i.theme&&k?(U='"):i.theme?(U='"):U=k?'':'',w=e(U),v&&(i.theme?(w.css(h),w.addClass("ui-widget-content")):w.css(s)),i.theme||I.css(i.overlayCSS),I.css("position",k?"fixed":"absolute"),(r||i.forceIframe)&&g.css("opacity",0);var C=[g,I,w],S=k?e("body"):e(o);e.each(C,function(){this.appendTo(S)}),i.theme&&i.draggable&&e.fn.draggable&&w.draggable({handle:".ui-dialog-titlebar",cancel:"li"});var O=f&&(!e.support.boxModel||e("object,embed",k?null:o).length>0);if(u||O){if(k&&i.allowBodyStretch&&e.support.boxModel&&e("html,body").css("height","100%"),(u||!e.support.boxModel)&&!k)var E=d(o,"borderTopWidth"),T=d(o,"borderLeftWidth"),M=E?"(0 - "+E+")":0,B=T?"(0 - "+T+")":0;e.each(C,function(e,o){var t=o[0].style;if(t.position="absolute",2>e)k?t.setExpression("height","Math.max(document.body.scrollHeight, document.body.offsetHeight) - (jQuery.support.boxModel?0:"+i.quirksmodeOffsetHack+') + "px"'):t.setExpression("height",'this.parentNode.offsetHeight + "px"'),k?t.setExpression("width",'jQuery.support.boxModel && document.documentElement.clientWidth || document.body.clientWidth + "px"'):t.setExpression("width",'this.parentNode.offsetWidth + "px"'),B&&t.setExpression("left",B),M&&t.setExpression("top",M);else if(i.centerY)k&&t.setExpression("top",'(document.documentElement.clientHeight || document.body.clientHeight) / 2 - (this.offsetHeight / 2) + (blah = document.documentElement.scrollTop ? document.documentElement.scrollTop : document.body.scrollTop) + "px"'),t.marginTop=0;else if(!i.centerY&&k){var n=i.css&&i.css.top?parseInt(i.css.top,10):0,s="((document.documentElement.scrollTop ? 
document.documentElement.scrollTop : document.body.scrollTop) + "+n+') + "px"';t.setExpression("top",s)}})}if(v&&(i.theme?w.find(".ui-widget-content").append(v):w.append(v),(v.jquery||v.nodeType)&&e(v).show()),(r||i.forceIframe)&&i.showOverlay&&g.show(),i.fadeIn){var j=i.onBlock?i.onBlock:c,H=i.showOverlay&&!v?j:c,z=v?j:c;i.showOverlay&&I._fadeIn(i.fadeIn,H),v&&w._fadeIn(i.fadeIn,z)}else i.showOverlay&&I.show(),v&&w.show(),i.onBlock&&i.onBlock();if(n(1,o,i),k?(b=w[0],p=e(i.focusableElements,b),i.focusInput&&setTimeout(l,20)):a(w[0],i.centerX,i.centerY),i.timeout){var W=setTimeout(function(){k?e.unblockUI(i):e(o).unblock(i)},i.timeout);e(o).data("blockUI.timeout",W)}}}function t(o,t){var s,l=o==window,a=e(o),d=a.data("blockUI.history"),c=a.data("blockUI.timeout");c&&(clearTimeout(c),a.removeData("blockUI.timeout")),t=e.extend({},e.blockUI.defaults,t||{}),n(0,o,t),null===t.onUnblock&&(t.onUnblock=a.data("blockUI.onUnblock"),a.removeData("blockUI.onUnblock"));var r;r=l?e("body").children().filter(".blockUI").add("body > .blockUI"):a.find(">.blockUI"),t.cursorReset&&(r.length>1&&(r[1].style.cursor=t.cursorReset),r.length>2&&(r[2].style.cursor=t.cursorReset)),l&&(b=p=null),t.fadeOut?(s=r.length,r.stop().fadeOut(t.fadeOut,function(){0===--s&&i(r,d,t,o)})):i(r,d,t,o)}function i(o,t,i,n){var s=e(n);if(!s.data("blockUI.isBlocked")){o.each(function(){this.parentNode&&this.parentNode.removeChild(this)}),t&&t.el&&(t.el.style.display=t.display,t.el.style.position=t.position,t.parent&&t.parent.appendChild(t.el),s.removeData("blockUI.history")),s.data("blockUI.static")&&s.css("position","static"),"function"==typeof i.onUnblock&&i.onUnblock(n,i);var l=e(document.body),a=l.width(),d=l[0].style.width;l.width(a-1).width(a),l[0].style.width=d}}function n(o,t,i){var n=t==window,l=e(t);if((o||(!n||b)&&(n||l.data("blockUI.isBlocked")))&&(l.data("blockUI.isBlocked",o),n&&i.bindEvents&&(!o||i.showOverlay))){var a="mousedown mouseup keydown keypress keyup touchstart touchend touchmove";o?e(document).bind(a,i,s):e(document).unbind(a,s)}}function s(o){if("keydown"===o.type&&o.keyCode&&9==o.keyCode&&b&&o.data.constrainTabKey){var t=p,i=!o.shiftKey&&o.target===t[t.length-1],n=o.shiftKey&&o.target===t[0];if(i||n)return setTimeout(function(){l(n)},10),!1}var s=o.data,a=e(o.target);return a.hasClass("blockOverlay")&&s.onOverlayClick&&s.onOverlayClick(o),a.parents("div."+s.blockMsgClass).length>0?!0:0===a.parents().children().filter("div.blockUI").length}function l(e){if(p){var o=p[e===!0?p.length-1:0];o&&o.focus()}}function a(e,o,t){var i=e.parentNode,n=e.style,s=(i.offsetWidth-e.offsetWidth)/2-d(i,"borderLeftWidth"),l=(i.offsetHeight-e.offsetHeight)/2-d(i,"borderTopWidth");o&&(n.left=s>0?s+"px":"0"),t&&(n.top=l>0?l+"px":"0")}function d(o,t){return parseInt(e.css(o,t),10)||0}e.fn._fadeIn=e.fn.fadeIn;var c=e.noop||function(){},r=/MSIE/.test(navigator.userAgent),u=/MSIE 6.0/.test(navigator.userAgent)&&!/MSIE 8.0/.test(navigator.userAgent);document.documentMode||0;var f=e.isFunction(document.createElement("div").style.setExpression);e.blockUI=function(e){o(window,e)},e.unblockUI=function(e){t(window,e)},e.growlUI=function(o,t,i,n){var s=e('
<div class="growlUI"></div>');o&&s.append("<h1>"+o+"</h1>"),t&&s.append("<h2>"+t+"</h2>
      "),void 0===i&&(i=3e3);var l=function(o){o=o||{},e.blockUI({message:s,fadeIn:o.fadeIn!==void 0?o.fadeIn:700,fadeOut:o.fadeOut!==void 0?o.fadeOut:1e3,timeout:o.timeout!==void 0?o.timeout:i,centerY:!1,showOverlay:!1,onUnblock:n,css:e.blockUI.defaults.growlCSS})};l(),s.css("opacity"),s.mouseover(function(){l({fadeIn:0,timeout:3e4});var o=e(".blockMsg");o.stop(),o.fadeTo(300,1)}).mouseout(function(){e(".blockMsg").fadeOut(1e3)})},e.fn.block=function(t){if(this[0]===window)return e.blockUI(t),this;var i=e.extend({},e.blockUI.defaults,t||{});return this.each(function(){var o=e(this);i.ignoreIfBlocked&&o.data("blockUI.isBlocked")||o.unblock({fadeOut:0})}),this.each(function(){"static"==e.css(this,"position")&&(this.style.position="relative",e(this).data("blockUI.static",!0)),this.style.zoom=1,o(this,t)})},e.fn.unblock=function(o){return this[0]===window?(e.unblockUI(o),this):this.each(function(){t(this,o)})},e.blockUI.version=2.66,e.blockUI.defaults={message:"

<h1>Please wait...</h1>
      ",title:null,draggable:!0,theme:!1,css:{padding:0,margin:0,width:"30%",top:"40%",left:"35%",textAlign:"center",color:"#000",border:"3px solid #aaa",backgroundColor:"#fff",cursor:"wait"},themedCSS:{width:"30%",top:"40%",left:"35%"},overlayCSS:{backgroundColor:"#000",opacity:.6,cursor:"wait"},cursorReset:"default",growlCSS:{width:"350px",top:"10px",left:"",right:"10px",border:"none",padding:"5px",opacity:.6,cursor:"default",color:"#fff",backgroundColor:"#000","-webkit-border-radius":"10px","-moz-border-radius":"10px","border-radius":"10px"},iframeSrc:/^https/i.test(window.location.href||"")?"javascript:false":"about:blank",forceIframe:!1,baseZ:1e3,centerX:!0,centerY:!0,allowBodyStretch:!0,bindEvents:!0,constrainTabKey:!0,fadeIn:200,fadeOut:400,timeout:0,showOverlay:!0,focusInput:!0,focusableElements:":input:enabled:visible",onBlock:null,onUnblock:null,onOverlayClick:null,quirksmodeOffsetHack:4,blockMsgClass:"blockMsg",ignoreIfBlocked:!1};var b=null,p=[]}"function"==typeof define&&define.amd&&define.amd.jQuery?define(["jquery"],e):e(jQuery)})(); \ No newline at end of file diff --git a/core/src/main/resources/org/apache/spark/ui/static/jquery.cookies.2.2.0.min.js b/core/src/main/resources/org/apache/spark/ui/static/jquery.cookies.2.2.0.min.js new file mode 100644 index 000000000000..bd2dacb4eeeb --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/jquery.cookies.2.2.0.min.js @@ -0,0 +1,18 @@ +/** + * Copyright (c) 2005 - 2010, James Auldridge + * All rights reserved. + * + * Licensed under the BSD, MIT, and GPL (your choice!) Licenses: + * http://code.google.com/p/cookies/wiki/License + * + */ +var jaaulde=window.jaaulde||{};jaaulde.utils=jaaulde.utils||{};jaaulde.utils.cookies=(function(){var resolveOptions,assembleOptionsString,parseCookies,constructor,defaultOptions={expiresAt:null,path:'/',domain:null,secure:false};resolveOptions=function(options){var returnValue,expireDate;if(typeof options!=='object'||options===null){returnValue=defaultOptions;}else +{returnValue={expiresAt:defaultOptions.expiresAt,path:defaultOptions.path,domain:defaultOptions.domain,secure:defaultOptions.secure};if(typeof options.expiresAt==='object'&&options.expiresAt instanceof Date){returnValue.expiresAt=options.expiresAt;}else if(typeof options.hoursToLive==='number'&&options.hoursToLive!==0){expireDate=new Date();expireDate.setTime(expireDate.getTime()+(options.hoursToLive*60*60*1000));returnValue.expiresAt=expireDate;}if(typeof options.path==='string'&&options.path!==''){returnValue.path=options.path;}if(typeof options.domain==='string'&&options.domain!==''){returnValue.domain=options.domain;}if(options.secure===true){returnValue.secure=options.secure;}}return returnValue;};assembleOptionsString=function(options){options=resolveOptions(options);return((typeof options.expiresAt==='object'&&options.expiresAt instanceof Date?'; expires='+options.expiresAt.toGMTString():'')+'; path='+options.path+(typeof options.domain==='string'?'; domain='+options.domain:'')+(options.secure===true?'; secure':''));};parseCookies=function(){var cookies={},i,pair,name,value,separated=document.cookie.split(';'),unparsedValue;for(i=0;i.sorting_1,table.dataTable.order-column tbody tr>.sorting_2,table.dataTable.order-column tbody tr>.sorting_3,table.dataTable.display tbody tr>.sorting_1,table.dataTable.display tbody tr>.sorting_2,table.dataTable.display tbody tr>.sorting_3{background-color:#f9f9f9}table.dataTable.order-column tbody tr.selected>.sorting_1,table.dataTable.order-column tbody 
tr.selected>.sorting_2,table.dataTable.order-column tbody tr.selected>.sorting_3,table.dataTable.display tbody tr.selected>.sorting_1,table.dataTable.display tbody tr.selected>.sorting_2,table.dataTable.display tbody tr.selected>.sorting_3{background-color:#acbad4}table.dataTable.display tbody tr.odd>.sorting_1,table.dataTable.order-column.stripe tbody tr.odd>.sorting_1{background-color:#f1f1f1}table.dataTable.display tbody tr.odd>.sorting_2,table.dataTable.order-column.stripe tbody tr.odd>.sorting_2{background-color:#f3f3f3}table.dataTable.display tbody tr.odd>.sorting_3,table.dataTable.order-column.stripe tbody tr.odd>.sorting_3{background-color:#f5f5f5}table.dataTable.display tbody tr.odd.selected>.sorting_1,table.dataTable.order-column.stripe tbody tr.odd.selected>.sorting_1{background-color:#a6b3cd}table.dataTable.display tbody tr.odd.selected>.sorting_2,table.dataTable.order-column.stripe tbody tr.odd.selected>.sorting_2{background-color:#a7b5ce}table.dataTable.display tbody tr.odd.selected>.sorting_3,table.dataTable.order-column.stripe tbody tr.odd.selected>.sorting_3{background-color:#a9b6d0}table.dataTable.display tbody tr.even>.sorting_1,table.dataTable.order-column.stripe tbody tr.even>.sorting_1{background-color:#f9f9f9}table.dataTable.display tbody tr.even>.sorting_2,table.dataTable.order-column.stripe tbody tr.even>.sorting_2{background-color:#fbfbfb}table.dataTable.display tbody tr.even>.sorting_3,table.dataTable.order-column.stripe tbody tr.even>.sorting_3{background-color:#fdfdfd}table.dataTable.display tbody tr.even.selected>.sorting_1,table.dataTable.order-column.stripe tbody tr.even.selected>.sorting_1{background-color:#acbad4}table.dataTable.display tbody tr.even.selected>.sorting_2,table.dataTable.order-column.stripe tbody tr.even.selected>.sorting_2{background-color:#adbbd6}table.dataTable.display tbody tr.even.selected>.sorting_3,table.dataTable.order-column.stripe tbody tr.even.selected>.sorting_3{background-color:#afbdd8}table.dataTable.display tbody tr:hover>.sorting_1,table.dataTable.display tbody tr.odd:hover>.sorting_1,table.dataTable.display tbody tr.even:hover>.sorting_1,table.dataTable.order-column.hover tbody tr:hover>.sorting_1,table.dataTable.order-column.hover tbody tr.odd:hover>.sorting_1,table.dataTable.order-column.hover tbody tr.even:hover>.sorting_1{background-color:#eaeaea}table.dataTable.display tbody tr:hover>.sorting_2,table.dataTable.display tbody tr.odd:hover>.sorting_2,table.dataTable.display tbody tr.even:hover>.sorting_2,table.dataTable.order-column.hover tbody tr:hover>.sorting_2,table.dataTable.order-column.hover tbody tr.odd:hover>.sorting_2,table.dataTable.order-column.hover tbody tr.even:hover>.sorting_2{background-color:#ebebeb}table.dataTable.display tbody tr:hover>.sorting_3,table.dataTable.display tbody tr.odd:hover>.sorting_3,table.dataTable.display tbody tr.even:hover>.sorting_3,table.dataTable.order-column.hover tbody tr:hover>.sorting_3,table.dataTable.order-column.hover tbody tr.odd:hover>.sorting_3,table.dataTable.order-column.hover tbody tr.even:hover>.sorting_3{background-color:#eee}table.dataTable.display tbody tr:hover.selected>.sorting_1,table.dataTable.display tbody tr.odd:hover.selected>.sorting_1,table.dataTable.display tbody tr.even:hover.selected>.sorting_1,table.dataTable.order-column.hover tbody tr:hover.selected>.sorting_1,table.dataTable.order-column.hover tbody tr.odd:hover.selected>.sorting_1,table.dataTable.order-column.hover tbody 
tr.even:hover.selected>.sorting_1{background-color:#a1aec7}table.dataTable.display tbody tr:hover.selected>.sorting_2,table.dataTable.display tbody tr.odd:hover.selected>.sorting_2,table.dataTable.display tbody tr.even:hover.selected>.sorting_2,table.dataTable.order-column.hover tbody tr:hover.selected>.sorting_2,table.dataTable.order-column.hover tbody tr.odd:hover.selected>.sorting_2,table.dataTable.order-column.hover tbody tr.even:hover.selected>.sorting_2{background-color:#a2afc8}table.dataTable.display tbody tr:hover.selected>.sorting_3,table.dataTable.display tbody tr.odd:hover.selected>.sorting_3,table.dataTable.display tbody tr.even:hover.selected>.sorting_3,table.dataTable.order-column.hover tbody tr:hover.selected>.sorting_3,table.dataTable.order-column.hover tbody tr.odd:hover.selected>.sorting_3,table.dataTable.order-column.hover tbody tr.even:hover.selected>.sorting_3{background-color:#a4b2cb}table.dataTable.no-footer{border-bottom:1px solid #111}table.dataTable.nowrap th,table.dataTable.nowrap td{white-space:nowrap}table.dataTable.compact thead th,table.dataTable.compact thead td{padding:5px 9px}table.dataTable.compact tfoot th,table.dataTable.compact tfoot td{padding:5px 9px 3px 9px}table.dataTable.compact tbody th,table.dataTable.compact tbody td{padding:4px 5px}table.dataTable th.dt-left,table.dataTable td.dt-left{text-align:left}table.dataTable th.dt-center,table.dataTable td.dt-center,table.dataTable td.dataTables_empty{text-align:center}table.dataTable th.dt-right,table.dataTable td.dt-right{text-align:right}table.dataTable th.dt-justify,table.dataTable td.dt-justify{text-align:justify}table.dataTable th.dt-nowrap,table.dataTable td.dt-nowrap{white-space:nowrap}table.dataTable thead th.dt-head-left,table.dataTable thead td.dt-head-left,table.dataTable tfoot th.dt-head-left,table.dataTable tfoot td.dt-head-left{text-align:left}table.dataTable thead th.dt-head-center,table.dataTable thead td.dt-head-center,table.dataTable tfoot th.dt-head-center,table.dataTable tfoot td.dt-head-center{text-align:center}table.dataTable thead th.dt-head-right,table.dataTable thead td.dt-head-right,table.dataTable tfoot th.dt-head-right,table.dataTable tfoot td.dt-head-right{text-align:right}table.dataTable thead th.dt-head-justify,table.dataTable thead td.dt-head-justify,table.dataTable tfoot th.dt-head-justify,table.dataTable tfoot td.dt-head-justify{text-align:justify}table.dataTable thead th.dt-head-nowrap,table.dataTable thead td.dt-head-nowrap,table.dataTable tfoot th.dt-head-nowrap,table.dataTable tfoot td.dt-head-nowrap{white-space:nowrap}table.dataTable tbody th.dt-body-left,table.dataTable tbody td.dt-body-left{text-align:left}table.dataTable tbody th.dt-body-center,table.dataTable tbody td.dt-body-center{text-align:center}table.dataTable tbody th.dt-body-right,table.dataTable tbody td.dt-body-right{text-align:right}table.dataTable tbody th.dt-body-justify,table.dataTable tbody td.dt-body-justify{text-align:justify}table.dataTable tbody th.dt-body-nowrap,table.dataTable tbody td.dt-body-nowrap{white-space:nowrap}table.dataTable,table.dataTable th,table.dataTable td{-webkit-box-sizing:content-box;-moz-box-sizing:content-box;box-sizing:content-box}.dataTables_wrapper{position:relative;clear:both;*zoom:1;zoom:1}.dataTables_wrapper .dataTables_length{float:left}.dataTables_wrapper .dataTables_filter{float:right;text-align:right}.dataTables_wrapper .dataTables_filter input{margin-left:0.5em}.dataTables_wrapper 
.dataTables_info{clear:both;float:left;padding-top:0.755em}.dataTables_wrapper .dataTables_paginate{float:right;text-align:right;padding-top:0.25em}.dataTables_wrapper .dataTables_paginate .paginate_button{box-sizing:border-box;display:inline-block;min-width:1.5em;padding:0.5em 1em;margin-left:2px;text-align:center;text-decoration:none !important;cursor:pointer;*cursor:hand;color:#333 !important;border:1px solid transparent}.dataTables_wrapper .dataTables_paginate .paginate_button.current,.dataTables_wrapper .dataTables_paginate .paginate_button.current:hover{color:#333 !important;border:1px solid #cacaca;background-color:#fff;background:-webkit-gradient(linear, left top, left bottom, color-stop(0%, #fff), color-stop(100%, #dcdcdc));background:-webkit-linear-gradient(top, #fff 0%, #dcdcdc 100%);background:-moz-linear-gradient(top, #fff 0%, #dcdcdc 100%);background:-ms-linear-gradient(top, #fff 0%, #dcdcdc 100%);background:-o-linear-gradient(top, #fff 0%, #dcdcdc 100%);background:linear-gradient(to bottom, #fff 0%, #dcdcdc 100%)}.dataTables_wrapper .dataTables_paginate .paginate_button.disabled,.dataTables_wrapper .dataTables_paginate .paginate_button.disabled:hover,.dataTables_wrapper .dataTables_paginate .paginate_button.disabled:active{cursor:default;color:#666 !important;border:1px solid transparent;background:transparent;box-shadow:none}.dataTables_wrapper .dataTables_paginate .paginate_button:hover{color:white !important;border:1px solid #111;background-color:#585858;background:-webkit-gradient(linear, left top, left bottom, color-stop(0%, #585858), color-stop(100%, #111));background:-webkit-linear-gradient(top, #585858 0%, #111 100%);background:-moz-linear-gradient(top, #585858 0%, #111 100%);background:-ms-linear-gradient(top, #585858 0%, #111 100%);background:-o-linear-gradient(top, #585858 0%, #111 100%);background:linear-gradient(to bottom, #585858 0%, #111 100%)}.dataTables_wrapper .dataTables_paginate .paginate_button:active{outline:none;background-color:#2b2b2b;background:-webkit-gradient(linear, left top, left bottom, color-stop(0%, #2b2b2b), color-stop(100%, #0c0c0c));background:-webkit-linear-gradient(top, #2b2b2b 0%, #0c0c0c 100%);background:-moz-linear-gradient(top, #2b2b2b 0%, #0c0c0c 100%);background:-ms-linear-gradient(top, #2b2b2b 0%, #0c0c0c 100%);background:-o-linear-gradient(top, #2b2b2b 0%, #0c0c0c 100%);background:linear-gradient(to bottom, #2b2b2b 0%, #0c0c0c 100%);box-shadow:inset 0 0 3px #111}.dataTables_wrapper .dataTables_processing{position:absolute;top:50%;left:50%;width:100%;height:40px;margin-left:-50%;margin-top:-25px;padding-top:20px;text-align:center;font-size:1.2em;background-color:white;background:-webkit-gradient(linear, left top, right top, color-stop(0%, rgba(255,255,255,0)), color-stop(25%, rgba(255,255,255,0.9)), color-stop(75%, rgba(255,255,255,0.9)), color-stop(100%, rgba(255,255,255,0)));background:-webkit-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%);background:-moz-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%);background:-ms-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%);background:-o-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%);background:linear-gradient(to right, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, 
rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%)}.dataTables_wrapper .dataTables_length,.dataTables_wrapper .dataTables_filter,.dataTables_wrapper .dataTables_info,.dataTables_wrapper .dataTables_processing,.dataTables_wrapper .dataTables_paginate{color:#333}.dataTables_wrapper .dataTables_scroll{clear:both}.dataTables_wrapper .dataTables_scroll div.dataTables_scrollBody{*margin-top:-1px;-webkit-overflow-scrolling:touch}.dataTables_wrapper .dataTables_scroll div.dataTables_scrollBody th>div.dataTables_sizing,.dataTables_wrapper .dataTables_scroll div.dataTables_scrollBody td>div.dataTables_sizing{height:0;overflow:hidden;margin:0 !important;padding:0 !important}.dataTables_wrapper.no-footer .dataTables_scrollBody{border-bottom:1px solid #111}.dataTables_wrapper.no-footer div.dataTables_scrollHead table,.dataTables_wrapper.no-footer div.dataTables_scrollBody table{border-bottom:none}.dataTables_wrapper:after{visibility:hidden;display:block;content:"";clear:both;height:0}@media screen and (max-width: 767px){.dataTables_wrapper .dataTables_info,.dataTables_wrapper .dataTables_paginate{float:none;text-align:center}.dataTables_wrapper .dataTables_paginate{margin-top:0.5em}}@media screen and (max-width: 640px){.dataTables_wrapper .dataTables_length,.dataTables_wrapper .dataTables_filter{float:none;text-align:center}.dataTables_wrapper .dataTables_filter{margin-top:0.5em}} diff --git a/core/src/main/resources/org/apache/spark/ui/static/jquery.dataTables.1.10.4.min.js b/core/src/main/resources/org/apache/spark/ui/static/jquery.dataTables.1.10.4.min.js new file mode 100644 index 000000000000..8885017c35d0 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/jquery.dataTables.1.10.4.min.js @@ -0,0 +1,157 @@ +/*! DataTables 1.10.4 + * ©2008-2014 SpryMedia Ltd - datatables.net/license + */ +(function(Da,P,l){var O=function(g){function V(a){var b,c,e={};g.each(a,function(d){if((b=d.match(/^([^A-Z]+?)([A-Z])/))&&-1!=="a aa ai ao as b fn i m o s ".indexOf(b[1]+" "))c=d.replace(b[0],b[2].toLowerCase()),e[c]=d,"o"===b[1]&&V(a[d])});a._hungarianMap=e}function G(a,b,c){a._hungarianMap||V(a);var e;g.each(b,function(d){e=a._hungarianMap[d];if(e!==l&&(c||b[e]===l))"o"===e.charAt(0)?(b[e]||(b[e]={}),g.extend(!0,b[e],b[d]),G(a[e],b[e],c)):b[e]=b[d]})}function O(a){var b=p.defaults.oLanguage,c=a.sZeroRecords; +!a.sEmptyTable&&(c&&"No data available in table"===b.sEmptyTable)&&D(a,a,"sZeroRecords","sEmptyTable");!a.sLoadingRecords&&(c&&"Loading..."===b.sLoadingRecords)&&D(a,a,"sZeroRecords","sLoadingRecords");a.sInfoThousands&&(a.sThousands=a.sInfoThousands);(a=a.sDecimal)&&cb(a)}function db(a){z(a,"ordering","bSort");z(a,"orderMulti","bSortMulti");z(a,"orderClasses","bSortClasses");z(a,"orderCellsTop","bSortCellsTop");z(a,"order","aaSorting");z(a,"orderFixed","aaSortingFixed");z(a,"paging","bPaginate"); +z(a,"pagingType","sPaginationType");z(a,"pageLength","iDisplayLength");z(a,"searching","bFilter");if(a=a.aoSearchCols)for(var b=0,c=a.length;b").css({position:"absolute",top:0,left:0,height:1,width:1,overflow:"hidden"}).append(g("
      ").css({position:"absolute",top:1,left:1,width:100, +overflow:"scroll"}).append(g('
      ').css({width:"100%",height:10}))).appendTo("body"),c=b.find(".test");a.bScrollOversize=100===c[0].offsetWidth;a.bScrollbarLeft=1!==c.offset().left;b.remove()}function gb(a,b,c,e,d,f){var h,i=!1;c!==l&&(h=c,i=!0);for(;e!==d;)a.hasOwnProperty(e)&&(h=i?b(h,a[e],e,a):a[e],i=!0,e+=f);return h}function Ea(a,b){var c=p.defaults.column,e=a.aoColumns.length,c=g.extend({},p.models.oColumn,c,{nTh:b?b:P.createElement("th"),sTitle:c.sTitle?c.sTitle:b?b.innerHTML: +"",aDataSort:c.aDataSort?c.aDataSort:[e],mData:c.mData?c.mData:e,idx:e});a.aoColumns.push(c);c=a.aoPreSearchCols;c[e]=g.extend({},p.models.oSearch,c[e]);ja(a,e,null)}function ja(a,b,c){var b=a.aoColumns[b],e=a.oClasses,d=g(b.nTh);if(!b.sWidthOrig){b.sWidthOrig=d.attr("width")||null;var f=(d.attr("style")||"").match(/width:\s*(\d+[pxem%]+)/);f&&(b.sWidthOrig=f[1])}c!==l&&null!==c&&(eb(c),G(p.defaults.column,c),c.mDataProp!==l&&!c.mData&&(c.mData=c.mDataProp),c.sType&&(b._sManualType=c.sType),c.className&& +!c.sClass&&(c.sClass=c.className),g.extend(b,c),D(b,c,"sWidth","sWidthOrig"),"number"===typeof c.iDataSort&&(b.aDataSort=[c.iDataSort]),D(b,c,"aDataSort"));var h=b.mData,i=W(h),j=b.mRender?W(b.mRender):null,c=function(a){return"string"===typeof a&&-1!==a.indexOf("@")};b._bAttrSrc=g.isPlainObject(h)&&(c(h.sort)||c(h.type)||c(h.filter));b.fnGetData=function(a,b,c){var e=i(a,b,l,c);return j&&b?j(e,b,a,c):e};b.fnSetData=function(a,b,c){return Q(h)(a,b,c)};"number"!==typeof h&&(a._rowReadObject=!0);a.oFeatures.bSort|| +(b.bSortable=!1,d.addClass(e.sSortableNone));a=-1!==g.inArray("asc",b.asSorting);c=-1!==g.inArray("desc",b.asSorting);!b.bSortable||!a&&!c?(b.sSortingClass=e.sSortableNone,b.sSortingClassJUI=""):a&&!c?(b.sSortingClass=e.sSortableAsc,b.sSortingClassJUI=e.sSortJUIAscAllowed):!a&&c?(b.sSortingClass=e.sSortableDesc,b.sSortingClassJUI=e.sSortJUIDescAllowed):(b.sSortingClass=e.sSortable,b.sSortingClassJUI=e.sSortJUI)}function X(a){if(!1!==a.oFeatures.bAutoWidth){var b=a.aoColumns;Fa(a);for(var c=0,e=b.length;c< +e;c++)b[c].nTh.style.width=b[c].sWidth}b=a.oScroll;(""!==b.sY||""!==b.sX)&&Y(a);u(a,null,"column-sizing",[a])}function ka(a,b){var c=Z(a,"bVisible");return"number"===typeof c[b]?c[b]:null}function $(a,b){var c=Z(a,"bVisible"),c=g.inArray(b,c);return-1!==c?c:null}function aa(a){return Z(a,"bVisible").length}function Z(a,b){var c=[];g.map(a.aoColumns,function(a,d){a[b]&&c.push(d)});return c}function Ga(a){var b=a.aoColumns,c=a.aoData,e=p.ext.type.detect,d,f,h,i,j,g,m,o,k;d=0;for(f=b.length;do[f])e(m.length+o[f],n);else if("string"===typeof o[f]){i=0;for(j=m.length;ib&&a[d]--; -1!=e&&c===l&& +a.splice(e,1)}function ca(a,b,c,e){var d=a.aoData[b],f,h=function(c,f){for(;c.childNodes.length;)c.removeChild(c.firstChild);c.innerHTML=v(a,b,f,"display")};if("dom"===c||(!c||"auto"===c)&&"dom"===d.src)d._aData=ma(a,d,e,e===l?l:d._aData).data;else{var i=d.anCells;if(i)if(e!==l)h(i[e],e);else{c=0;for(f=i.length;c").appendTo(h));b=0;for(c= +m.length;btr").attr("role","row");g(h).find(">tr>th, >tr>td").addClass(n.sHeaderTH);g(i).find(">tr>th, >tr>td").addClass(n.sFooterTH);if(null!==i){a=a.aoFooter[0];b=0;for(c=a.length;b=a.fnRecordsDisplay()?0:h,a.iInitDisplayStart=-1);var h=a._iDisplayStart,n=a.fnDisplayEnd();if(a.bDeferLoading)a.bDeferLoading= +!1,a.iDraw++,B(a,!1);else if(i){if(!a.bDestroying&&!jb(a))return}else a.iDraw++;if(0!==j.length){f=i?a.aoData.length:n;for(i=i?0:h;i",{"class":d? 
+e[0]:""}).append(g("",{valign:"top",colSpan:aa(a),"class":a.oClasses.sRowEmpty}).html(c))[0];u(a,"aoHeaderCallback","header",[g(a.nTHead).children("tr")[0],Ka(a),h,n,j]);u(a,"aoFooterCallback","footer",[g(a.nTFoot).children("tr")[0],Ka(a),h,n,j]);e=g(a.nTBody);e.children().detach();e.append(g(b));u(a,"aoDrawCallback","draw",[a]);a.bSorted=!1;a.bFiltered=!1;a.bDrawing=!1}}function M(a,b){var c=a.oFeatures,e=c.bFilter;c.bSort&&kb(a);e?fa(a,a.oPreviousSearch):a.aiDisplay=a.aiDisplayMaster.slice(); +!0!==b&&(a._iDisplayStart=0);a._drawHold=b;L(a);a._drawHold=!1}function lb(a){var b=a.oClasses,c=g(a.nTable),c=g("
      ").insertBefore(c),e=a.oFeatures,d=g("
      ",{id:a.sTableId+"_wrapper","class":b.sWrapper+(a.nTFoot?"":" "+b.sNoFooter)});a.nHolding=c[0];a.nTableWrapper=d[0];a.nTableReinsertBefore=a.nTable.nextSibling;for(var f=a.sDom.split(""),h,i,j,n,m,o,k=0;k")[0];n=f[k+1];if("'"==n||'"'==n){m="";for(o=2;f[k+o]!=n;)m+=f[k+o],o++;"H"==m?m=b.sJUIHeader: +"F"==m&&(m=b.sJUIFooter);-1!=m.indexOf(".")?(n=m.split("."),j.id=n[0].substr(1,n[0].length-1),j.className=n[1]):"#"==m.charAt(0)?j.id=m.substr(1,m.length-1):j.className=m;k+=o}d.append(j);d=g(j)}else if(">"==i)d=d.parent();else if("l"==i&&e.bPaginate&&e.bLengthChange)h=mb(a);else if("f"==i&&e.bFilter)h=nb(a);else if("r"==i&&e.bProcessing)h=ob(a);else if("t"==i)h=pb(a);else if("i"==i&&e.bInfo)h=qb(a);else if("p"==i&&e.bPaginate)h=rb(a);else if(0!==p.ext.feature.length){j=p.ext.feature;o=0;for(n=j.length;o< +n;o++)if(i==j[o].cFeature){h=j[o].fnInit(a);break}}h&&(j=a.aanFeatures,j[i]||(j[i]=[]),j[i].push(h),d.append(h))}c.replaceWith(d)}function da(a,b){var c=g(b).children("tr"),e,d,f,h,i,j,n,m,o,k;a.splice(0,a.length);f=0;for(j=c.length;f',i=e.sSearch,i=i.match(/_INPUT_/)?i.replace("_INPUT_",h):i+h,b=g("
      ",{id:!f.f?c+"_filter":null,"class":b.sFilter}).append(g("
      ").addClass(b.sLength);a.aanFeatures.l||(j[0].id=c+"_length");j.children().append(a.oLanguage.sLengthMenu.replace("_MENU_",d[0].outerHTML));g("select",j).val(a._iDisplayLength).bind("change.DT", +function(){Qa(a,g(this).val());L(a)});g(a.nTable).bind("length.dt.DT",function(b,c,f){a===c&&g("select",j).val(f)});return j[0]}function rb(a){var b=a.sPaginationType,c=p.ext.pager[b],e="function"===typeof c,d=function(a){L(a)},b=g("
      ").addClass(a.oClasses.sPaging+b)[0],f=a.aanFeatures;e||c.fnInit(a,b,d);f.p||(b.id=a.sTableId+"_paginate",a.aoDrawCallback.push({fn:function(a){if(e){var b=a._iDisplayStart,g=a._iDisplayLength,n=a.fnRecordsDisplay(),m=-1===g,b=m?0:Math.ceil(b/g),g=m?1:Math.ceil(n/ +g),n=c(b,g),o,m=0;for(o=f.p.length;mf&&(e=0)):"first"==b?e=0:"previous"==b?(e=0<=d?e-d:0,0>e&&(e=0)):"next"==b?e+d",{id:!a.aanFeatures.r?a.sTableId+"_processing":null,"class":a.oClasses.sProcessing}).html(a.oLanguage.sProcessing).insertBefore(a.nTable)[0]}function B(a,b){a.oFeatures.bProcessing&&g(a.aanFeatures.r).css("display",b?"block":"none");u(a,null,"processing",[a,b])}function pb(a){var b=g(a.nTable);b.attr("role","grid");var c=a.oScroll;if(""===c.sX&&""===c.sY)return a.nTable;var e=c.sX,d=c.sY,f=a.oClasses,h=b.children("caption"),i=h.length?h[0]._captionSide:null, +j=g(b[0].cloneNode(!1)),n=g(b[0].cloneNode(!1)),m=b.children("tfoot");c.sX&&"100%"===b.attr("width")&&b.removeAttr("width");m.length||(m=null);c=g("
      ",{"class":f.sScrollWrapper}).append(g("
      ",{"class":f.sScrollHead}).css({overflow:"hidden",position:"relative",border:0,width:e?!e?null:s(e):"100%"}).append(g("
      ",{"class":f.sScrollHeadInner}).css({"box-sizing":"content-box",width:c.sXInner||"100%"}).append(j.removeAttr("id").css("margin-left",0).append("top"===i?h:null).append(b.children("thead"))))).append(g("
      ", +{"class":f.sScrollBody}).css({overflow:"auto",height:!d?null:s(d),width:!e?null:s(e)}).append(b));m&&c.append(g("
      ",{"class":f.sScrollFoot}).css({overflow:"hidden",border:0,width:e?!e?null:s(e):"100%"}).append(g("
      ",{"class":f.sScrollFootInner}).append(n.removeAttr("id").css("margin-left",0).append("bottom"===i?h:null).append(b.children("tfoot")))));var b=c.children(),o=b[0],f=b[1],k=m?b[2]:null;e&&g(f).scroll(function(){var a=this.scrollLeft;o.scrollLeft=a;m&&(k.scrollLeft=a)});a.nScrollHead= +o;a.nScrollBody=f;a.nScrollFoot=k;a.aoDrawCallback.push({fn:Y,sName:"scrolling"});return c[0]}function Y(a){var b=a.oScroll,c=b.sX,e=b.sXInner,d=b.sY,f=b.iBarWidth,h=g(a.nScrollHead),i=h[0].style,j=h.children("div"),n=j[0].style,m=j.children("table"),j=a.nScrollBody,o=g(j),k=j.style,l=g(a.nScrollFoot).children("div"),p=l.children("table"),r=g(a.nTHead),q=g(a.nTable),t=q[0],N=t.style,J=a.nTFoot?g(a.nTFoot):null,u=a.oBrowser,w=u.bScrollOversize,y,v,x,K,z,A=[],B=[],C=[],D,E=function(a){a=a.style;a.paddingTop= +"0";a.paddingBottom="0";a.borderTopWidth="0";a.borderBottomWidth="0";a.height=0};q.children("thead, tfoot").remove();z=r.clone().prependTo(q);y=r.find("tr");x=z.find("tr");z.find("th, td").removeAttr("tabindex");J&&(K=J.clone().prependTo(q),v=J.find("tr"),K=K.find("tr"));c||(k.width="100%",h[0].style.width="100%");g.each(pa(a,z),function(b,c){D=ka(a,b);c.style.width=a.aoColumns[D].sWidth});J&&F(function(a){a.style.width=""},K);b.bCollapse&&""!==d&&(k.height=o[0].offsetHeight+r[0].offsetHeight+"px"); +h=q.outerWidth();if(""===c){if(N.width="100%",w&&(q.find("tbody").height()>j.offsetHeight||"scroll"==o.css("overflow-y")))N.width=s(q.outerWidth()-f)}else""!==e?N.width=s(e):h==o.width()&&o.height()h-f&&(N.width=s(h))):N.width=s(h);h=q.outerWidth();F(E,x);F(function(a){C.push(a.innerHTML);A.push(s(g(a).css("width")))},x);F(function(a,b){a.style.width=A[b]},y);g(x).height(0);J&&(F(E,K),F(function(a){B.push(s(g(a).css("width")))},K),F(function(a,b){a.style.width= +B[b]},v),g(K).height(0));F(function(a,b){a.innerHTML='
      '+C[b]+"
      ";a.style.width=A[b]},x);J&&F(function(a,b){a.innerHTML="";a.style.width=B[b]},K);if(q.outerWidth()j.offsetHeight||"scroll"==o.css("overflow-y")?h+f:h;if(w&&(j.scrollHeight>j.offsetHeight||"scroll"==o.css("overflow-y")))N.width=s(v-f);(""===c||""!==e)&&R(a,1,"Possible column misalignment",6)}else v="100%";k.width=s(v);i.width=s(v);J&&(a.nScrollFoot.style.width= +s(v));!d&&w&&(k.height=s(t.offsetHeight+f));d&&b.bCollapse&&(k.height=s(d),b=c&&t.offsetWidth>j.offsetWidth?f:0,t.offsetHeightj.clientHeight||"scroll"==o.css("overflow-y");u="padding"+(u.bScrollbarLeft?"Left":"Right");n[u]=m?f+"px":"0px";J&&(p[0].style.width=s(b),l[0].style.width=s(b),l[0].style[u]=m?f+"px":"0px");o.scroll();if((a.bSorted||a.bFiltered)&&!a._drawHold)j.scrollTop=0}function F(a, +b,c){for(var e=0,d=0,f=b.length,h,g;d"));i.find("tfoot th, tfoot td").css("width","");var p=i.find("tbody tr"),j=pa(a,i.find("thead")[0]);for(k=0;k").css("width",s(a)).appendTo(b||P.body),e=c[0].offsetWidth;c.remove();return e}function Eb(a,b){var c=a.oScroll;if(c.sX||c.sY)c=!c.sX?c.iBarWidth:0,b.style.width=s(g(b).outerWidth()-c)}function Db(a,b){var c=Fb(a,b);if(0>c)return null; +var e=a.aoData[c];return!e.nTr?g("").html(v(a,c,b,"display"))[0]:e.anCells[b]}function Fb(a,b){for(var c,e=-1,d=-1,f=0,h=a.aoData.length;fe&&(e=c.length,d=f);return d}function s(a){return null===a?"0px":"number"==typeof a?0>a?"0px":a+"px":a.match(/\d$/)?a+"px":a}function Gb(){if(!p.__scrollbarWidth){var a=g("

      ").css({width:"100%",height:200,padding:0})[0],b=g("

      ").css({position:"absolute",top:0,left:0,width:200,height:150,padding:0, +overflow:"hidden",visibility:"hidden"}).append(a).appendTo("body"),c=a.offsetWidth;b.css("overflow","scroll");a=a.offsetWidth;c===a&&(a=b[0].clientWidth);b.remove();p.__scrollbarWidth=c-a}return p.__scrollbarWidth}function T(a){var b,c,e=[],d=a.aoColumns,f,h,i,j;b=a.aaSortingFixed;c=g.isPlainObject(b);var n=[];f=function(a){a.length&&!g.isArray(a[0])?n.push(a):n.push.apply(n,a)};g.isArray(b)&&f(b);c&&b.pre&&f(b.pre);f(a.aaSorting);c&&b.post&&f(b.post);for(a=0;ad?1:0,0!==c)return"asc"===g.dir?c:-c;c=e[a];d=e[b];return cd?1:0}):j.sort(function(a,b){var c,h,g,i,j=n.length,l=f[a]._aSortData,p=f[b]._aSortData;for(g=0;gh?1:0})}a.bSorted=!0}function Ib(a){for(var b,c,e=a.aoColumns,d=T(a),a=a.oLanguage.oAria,f=0,h=e.length;f/g,"");var j=c.nTh;j.removeAttribute("aria-sort");c.bSortable&&(0d?d+1:3));d=0;for(f=e.length;dd?d+1:3))}a.aLastSort=e}function Hb(a,b){var c=a.aoColumns[b],e=p.ext.order[c.sSortDataType],d;e&&(d=e.call(a.oInstance,a,b,$(a,b)));for(var f,h=p.ext.type.order[c.sType+"-pre"],g=0,j=a.aoData.length;g< +j;g++)if(c=a.aoData[g],c._aSortData||(c._aSortData=[]),!c._aSortData[b]||e)f=e?d[g]:v(a,g,b,"sort"),c._aSortData[b]=h?h(f):f}function xa(a){if(a.oFeatures.bStateSave&&!a.bDestroying){var b={time:+new Date,start:a._iDisplayStart,length:a._iDisplayLength,order:g.extend(!0,[],a.aaSorting),search:yb(a.oPreviousSearch),columns:g.map(a.aoColumns,function(b,e){return{visible:b.bVisible,search:yb(a.aoPreSearchCols[e])}})};u(a,"aoStateSaveParams","stateSaveParams",[a,b]);a.oSavedState=b;a.fnStateSaveCallback.call(a.oInstance, +a,b)}}function Jb(a){var b,c,e=a.aoColumns;if(a.oFeatures.bStateSave){var d=a.fnStateLoadCallback.call(a.oInstance,a);if(d&&d.time&&(b=u(a,"aoStateLoadParams","stateLoadParams",[a,d]),-1===g.inArray(!1,b)&&(b=a.iStateDuration,!(0=e.length?[0,c[1]]:c)});g.extend(a.oPreviousSearch, +zb(d.search));b=0;for(c=d.columns.length;b=c&&(b=c-e);b-=b%e;if(-1===e||0>b)b=0;a._iDisplayStart=b}function Oa(a,b){var c=a.renderer,e=p.ext.renderer[b];return g.isPlainObject(c)&& +c[b]?e[c[b]]||e._:"string"===typeof c?e[c]||e._:e._}function A(a){return a.oFeatures.bServerSide?"ssp":a.ajax||a.sAjaxSource?"ajax":"dom"}function Va(a,b){var c=[],c=Lb.numbers_length,e=Math.floor(c/2);b<=c?c=U(0,b):a<=e?(c=U(0,c-2),c.push("ellipsis"),c.push(b-1)):(a>=b-1-e?c=U(b-(c-2),b):(c=U(a-1,a+2),c.push("ellipsis"),c.push(b-1)),c.splice(0,0,"ellipsis"),c.splice(0,0,0));c.DT_el="span";return c}function cb(a){g.each({num:function(b){return za(b,a)},"num-fmt":function(b){return za(b,a,Wa)},"html-num":function(b){return za(b, +a,Aa)},"html-num-fmt":function(b){return za(b,a,Aa,Wa)}},function(b,c){w.type.order[b+a+"-pre"]=c;b.match(/^html\-/)&&(w.type.search[b+a]=w.type.search.html)})}function Mb(a){return function(){var b=[ya(this[p.ext.iApiIndex])].concat(Array.prototype.slice.call(arguments));return p.ext.internal[a].apply(this,b)}}var p,w,q,r,t,Xa={},Nb=/[\r\n]/g,Aa=/<.*?>/g,$b=/^[\w\+\-]/,ac=/[\w\+\-]$/,Xb=RegExp("(\\/|\\.|\\*|\\+|\\?|\\||\\(|\\)|\\[|\\]|\\{|\\}|\\\\|\\$|\\^|\\-)","g"),Wa=/[',$\u00a3\u20ac\u00a5%\u2009\u202F]/g, +H=function(a){return!a||!0===a||"-"===a?!0:!1},Ob=function(a){var b=parseInt(a,10);return!isNaN(b)&&isFinite(a)?b:null},Pb=function(a,b){Xa[b]||(Xa[b]=RegExp(ua(b),"g"));return"string"===typeof a&&"."!==b?a.replace(/\./g,"").replace(Xa[b],"."):a},Ya=function(a,b,c){var e="string"===typeof a;b&&e&&(a=Pb(a,b));c&&e&&(a=a.replace(Wa,""));return 
H(a)||!isNaN(parseFloat(a))&&isFinite(a)},Qb=function(a,b,c){return H(a)?!0:!(H(a)||"string"===typeof a)?null:Ya(a.replace(Aa,""),b,c)?!0:null},C=function(a, +b,c){var e=[],d=0,f=a.length;if(c!==l)for(;d")[0],Yb=va.textContent!==l,Zb=/<.*?>/g;p=function(a){this.$=function(a,b){return this.api(!0).$(a,b)};this._=function(a,b){return this.api(!0).rows(a,b).data()};this.api=function(a){return a?new q(ya(this[w.iApiIndex])):new q(this)};this.fnAddData=function(a,b){var c=this.api(!0),e=g.isArray(a)&&(g.isArray(a[0])||g.isPlainObject(a[0]))? +c.rows.add(a):c.row.add(a);(b===l||b)&&c.draw();return e.flatten().toArray()};this.fnAdjustColumnSizing=function(a){var b=this.api(!0).columns.adjust(),c=b.settings()[0],e=c.oScroll;a===l||a?b.draw(!1):(""!==e.sX||""!==e.sY)&&Y(c)};this.fnClearTable=function(a){var b=this.api(!0).clear();(a===l||a)&&b.draw()};this.fnClose=function(a){this.api(!0).row(a).child.hide()};this.fnDeleteRow=function(a,b,c){var e=this.api(!0),a=e.rows(a),d=a.settings()[0],g=d.aoData[a[0][0]];a.remove();b&&b.call(this,d,g); +(c===l||c)&&e.draw();return g};this.fnDestroy=function(a){this.api(!0).destroy(a)};this.fnDraw=function(a){this.api(!0).draw(!a)};this.fnFilter=function(a,b,c,e,d,g){d=this.api(!0);null===b||b===l?d.search(a,c,e,g):d.column(b).search(a,c,e,g);d.draw()};this.fnGetData=function(a,b){var c=this.api(!0);if(a!==l){var e=a.nodeName?a.nodeName.toLowerCase():"";return b!==l||"td"==e||"th"==e?c.cell(a,b).data():c.row(a).data()||null}return c.data().toArray()};this.fnGetNodes=function(a){var b=this.api(!0); +return a!==l?b.row(a).node():b.rows().nodes().flatten().toArray()};this.fnGetPosition=function(a){var b=this.api(!0),c=a.nodeName.toUpperCase();return"TR"==c?b.row(a).index():"TD"==c||"TH"==c?(a=b.cell(a).index(),[a.row,a.columnVisible,a.column]):null};this.fnIsOpen=function(a){return this.api(!0).row(a).child.isShown()};this.fnOpen=function(a,b,c){return this.api(!0).row(a).child(b,c).show().child()[0]};this.fnPageChange=function(a,b){var c=this.api(!0).page(a);(b===l||b)&&c.draw(!1)};this.fnSetColumnVis= +function(a,b,c){a=this.api(!0).column(a).visible(b);(c===l||c)&&a.columns.adjust().draw()};this.fnSettings=function(){return ya(this[w.iApiIndex])};this.fnSort=function(a){this.api(!0).order(a).draw()};this.fnSortListener=function(a,b,c){this.api(!0).order.listener(a,b,c)};this.fnUpdate=function(a,b,c,e,d){var g=this.api(!0);c===l||null===c?g.row(b).data(a):g.cell(b,c).data(a);(d===l||d)&&g.columns.adjust();(e===l||e)&&g.draw();return 0};this.fnVersionCheck=w.fnVersionCheck;var b=this,c=a===l,e=this.length; +c&&(a={});this.oApi=this.internal=w.internal;for(var d in p.ext.internal)d&&(this[d]=Mb(d));this.each(function(){var d={},d=1t<"F"ip>'),k.renderer)? 
+g.isPlainObject(k.renderer)&&!k.renderer.header&&(k.renderer.header="jqueryui"):k.renderer="jqueryui":g.extend(j,p.ext.classes,d.oClasses);g(this).addClass(j.sTable);if(""!==k.oScroll.sX||""!==k.oScroll.sY)k.oScroll.iBarWidth=Gb();!0===k.oScroll.sX&&(k.oScroll.sX="100%");k.iInitDisplayStart===l&&(k.iInitDisplayStart=d.iDisplayStart,k._iDisplayStart=d.iDisplayStart);null!==d.iDeferLoading&&(k.bDeferLoading=!0,h=g.isArray(d.iDeferLoading),k._iRecordsDisplay=h?d.iDeferLoading[0]:d.iDeferLoading,k._iRecordsTotal= +h?d.iDeferLoading[1]:d.iDeferLoading);var r=k.oLanguage;g.extend(!0,r,d.oLanguage);""!==r.sUrl&&(g.ajax({dataType:"json",url:r.sUrl,success:function(a){O(a);G(m.oLanguage,a);g.extend(true,r,a);ga(k)},error:function(){ga(k)}}),n=!0);null===d.asStripeClasses&&(k.asStripeClasses=[j.sStripeOdd,j.sStripeEven]);var h=k.asStripeClasses,q=g("tbody tr:eq(0)",this);-1!==g.inArray(!0,g.map(h,function(a){return q.hasClass(a)}))&&(g("tbody tr",this).removeClass(h.join(" ")),k.asDestroyStripes=h.slice());var o= +[],s,h=this.getElementsByTagName("thead");0!==h.length&&(da(k.aoHeader,h[0]),o=pa(k));if(null===d.aoColumns){s=[];h=0;for(i=o.length;h").appendTo(this));k.nTHead=i[0];i=g(this).children("tbody");0===i.length&&(i=g("").appendTo(this));k.nTBody=i[0];i=g(this).children("tfoot");if(0===i.length&&0").appendTo(this);0===i.length||0===i.children().length?g(this).addClass(j.sNoFooter): +0a?new q(b[a],this[a]):null},filter:function(a){var b=[];if(y.filter)b=y.filter.call(this,a,this);else for(var c=0,e=this.length;c").addClass(b);g("td",c).addClass(b).html(a)[0].colSpan=aa(e);d.push(c[0])}};if(g.isArray(a)||a instanceof g)for(var h=0,i=a.length;h=0?b:h.length+b];if(typeof a==="function"){var d=Ba(c,f);return g.map(h,function(b,f){return a(f,Vb(c,f,0,0,d),j[f])?f:null})}var k=typeof a==="string"?a.match(cc):"";if(k)switch(k[2]){case "visIdx":case "visible":b= +parseInt(k[1],10);if(b<0){var l=g.map(h,function(a,b){return a.bVisible?b:null});return[l[l.length+b]]}return[ka(c,b)];case "name":return g.map(i,function(a,b){return a===k[1]?b:null})}else return g(j).filter(a).map(function(){return g.inArray(this,j)}).toArray()})},1);c.selector.cols=a;c.selector.opts=b;return c});t("columns().header()","column().header()",function(){return this.iterator("column",function(a,b){return a.aoColumns[b].nTh},1)});t("columns().footer()","column().footer()",function(){return this.iterator("column", +function(a,b){return a.aoColumns[b].nTf},1)});t("columns().data()","column().data()",function(){return this.iterator("column-rows",Vb,1)});t("columns().dataSrc()","column().dataSrc()",function(){return this.iterator("column",function(a,b){return a.aoColumns[b].mData},1)});t("columns().cache()","column().cache()",function(a){return this.iterator("column-rows",function(b,c,e,d,f){return ha(b.aoData,f,"search"===a?"_aFilterData":"_aSortData",c)},1)});t("columns().nodes()","column().nodes()",function(){return this.iterator("column-rows", +function(a,b,c,e,d){return ha(a.aoData,d,"anCells",b)},1)});t("columns().visible()","column().visible()",function(a,b){return this.iterator("column",function(c,e){if(a===l)return c.aoColumns[e].bVisible;var d=c.aoColumns,f=d[e],h=c.aoData,i,j,n;if(a!==l&&f.bVisible!==a){if(a){var m=g.inArray(!0,C(d,"bVisible"),e+1);i=0;for(j=h.length;ie;return!0};p.isDataTable=p.fnIsDataTable=function(a){var b=g(a).get(0),c=!1;g.each(p.settings,function(a,d){if(d.nTable===b||d.nScrollHead===b||d.nScrollFoot===b)c=!0});return c};p.tables=p.fnTables=function(a){return 
g.map(p.settings,function(b){if(!a||a&&g(b.nTable).is(":visible"))return b.nTable})};p.util={throttle:ta,escapeRegex:ua}; +p.camelToHungarian=G;r("$()",function(a,b){var c=this.rows(b).nodes(),c=g(c);return g([].concat(c.filter(a).toArray(),c.find(a).toArray()))});g.each(["on","one","off"],function(a,b){r(b+"()",function(){var a=Array.prototype.slice.call(arguments);a[0].match(/\.dt\b/)||(a[0]+=".dt");var e=g(this.tables().nodes());e[b].apply(e,a);return this})});r("clear()",function(){return this.iterator("table",function(a){na(a)})});r("settings()",function(){return new q(this.context,this.context)});r("data()",function(){return this.iterator("table", +function(a){return C(a.aoData,"_aData")}).flatten()});r("destroy()",function(a){a=a||!1;return this.iterator("table",function(b){var c=b.nTableWrapper.parentNode,e=b.oClasses,d=b.nTable,f=b.nTBody,h=b.nTHead,i=b.nTFoot,j=g(d),f=g(f),l=g(b.nTableWrapper),m=g.map(b.aoData,function(a){return a.nTr}),o;b.bDestroying=!0;u(b,"aoDestroyCallback","destroy",[b]);a||(new q(b)).columns().visible(!0);l.unbind(".DT").find(":not(tbody *)").unbind(".DT");g(Da).unbind(".DT-"+b.sInstance);d!=h.parentNode&&(j.children("thead").detach(), +j.append(h));i&&d!=i.parentNode&&(j.children("tfoot").detach(),j.append(i));j.detach();l.detach();b.aaSorting=[];b.aaSortingFixed=[];wa(b);g(m).removeClass(b.asStripeClasses.join(" "));g("th, td",h).removeClass(e.sSortable+" "+e.sSortableAsc+" "+e.sSortableDesc+" "+e.sSortableNone);b.bJUI&&(g("th span."+e.sSortIcon+", td span."+e.sSortIcon,h).detach(),g("th, td",h).each(function(){var a=g("div."+e.sSortJUIWrapper,this);g(this).append(a.contents());a.detach()}));!a&&c&&c.insertBefore(d,b.nTableReinsertBefore); +f.children().detach();f.append(m);j.css("width",b.sDestroyWidth).removeClass(e.sTable);(o=b.asDestroyStripes.length)&&f.children().each(function(a){g(this).addClass(b.asDestroyStripes[a%o])});c=g.inArray(b,p.settings);-1!==c&&p.settings.splice(c,1)})});p.version="1.10.4";p.settings=[];p.models={};p.models.oSearch={bCaseInsensitive:!0,sSearch:"",bRegex:!1,bSmart:!0};p.models.oRow={nTr:null,anCells:null,_aData:[],_aSortData:null,_aFilterData:null,_sFilterRow:null,_sRowStripe:"",src:null};p.models.oColumn= +{idx:null,aDataSort:null,asSorting:null,bSearchable:null,bSortable:null,bVisible:null,_sManualType:null,_bAttrSrc:!1,fnCreatedCell:null,fnGetData:null,fnSetData:null,mData:null,mRender:null,nTh:null,nTf:null,sClass:null,sContentPadding:null,sDefaultContent:null,sName:null,sSortDataType:"std",sSortingClass:null,sSortingClassJUI:null,sTitle:null,sType:null,sWidth:null,sWidthOrig:null};p.defaults={aaData:null,aaSorting:[[0,"asc"]],aaSortingFixed:[],ajax:null,aLengthMenu:[10,25,50,100],aoColumns:null, +aoColumnDefs:null,aoSearchCols:[],asStripeClasses:null,bAutoWidth:!0,bDeferRender:!1,bDestroy:!1,bFilter:!0,bInfo:!0,bJQueryUI:!1,bLengthChange:!0,bPaginate:!0,bProcessing:!1,bRetrieve:!1,bScrollCollapse:!1,bServerSide:!1,bSort:!0,bSortMulti:!0,bSortCellsTop:!1,bSortClasses:!0,bStateSave:!1,fnCreatedRow:null,fnDrawCallback:null,fnFooterCallback:null,fnFormatNumber:function(a){return a.toString().replace(/\B(?=(\d{3})+(?!\d))/g,this.oLanguage.sThousands)},fnHeaderCallback:null,fnInfoCallback:null, +fnInitComplete:null,fnPreDrawCallback:null,fnRowCallback:null,fnServerData:null,fnServerParams:null,fnStateLoadCallback:function(a){try{return 
JSON.parse((-1===a.iStateDuration?sessionStorage:localStorage).getItem("DataTables_"+a.sInstance+"_"+location.pathname))}catch(b){}},fnStateLoadParams:null,fnStateLoaded:null,fnStateSaveCallback:function(a,b){try{(-1===a.iStateDuration?sessionStorage:localStorage).setItem("DataTables_"+a.sInstance+"_"+location.pathname,JSON.stringify(b))}catch(c){}},fnStateSaveParams:null, +iStateDuration:7200,iDeferLoading:null,iDisplayLength:10,iDisplayStart:0,iTabIndex:0,oClasses:{},oLanguage:{oAria:{sSortAscending:": activate to sort column ascending",sSortDescending:": activate to sort column descending"},oPaginate:{sFirst:"First",sLast:"Last",sNext:"Next",sPrevious:"Previous"},sEmptyTable:"No data available in table",sInfo:"Showing _START_ to _END_ of _TOTAL_ entries",sInfoEmpty:"Showing 0 to 0 of 0 entries",sInfoFiltered:"(filtered from _MAX_ total entries)",sInfoPostFix:"",sDecimal:"", +sThousands:",",sLengthMenu:"Show _MENU_ entries",sLoadingRecords:"Loading...",sProcessing:"Processing...",sSearch:"Search:",sSearchPlaceholder:"",sUrl:"",sZeroRecords:"No matching records found"},oSearch:g.extend({},p.models.oSearch),sAjaxDataProp:"data",sAjaxSource:null,sDom:"lfrtip",searchDelay:null,sPaginationType:"simple_numbers",sScrollX:"",sScrollXInner:"",sScrollY:"",sServerMethod:"GET",renderer:null};V(p.defaults);p.defaults.column={aDataSort:null,iDataSort:-1,asSorting:["asc","desc"],bSearchable:!0, +bSortable:!0,bVisible:!0,fnCreatedCell:null,mData:null,mRender:null,sCellType:"td",sClass:"",sContentPadding:"",sDefaultContent:null,sName:"",sSortDataType:"std",sTitle:null,sType:null,sWidth:null};V(p.defaults.column);p.models.oSettings={oFeatures:{bAutoWidth:null,bDeferRender:null,bFilter:null,bInfo:null,bLengthChange:null,bPaginate:null,bProcessing:null,bServerSide:null,bSort:null,bSortMulti:null,bSortClasses:null,bStateSave:null},oScroll:{bCollapse:null,iBarWidth:0,sX:null,sXInner:null,sY:null}, +oLanguage:{fnInfoCallback:null},oBrowser:{bScrollOversize:!1,bScrollbarLeft:!1},ajax:null,aanFeatures:[],aoData:[],aiDisplay:[],aiDisplayMaster:[],aoColumns:[],aoHeader:[],aoFooter:[],oPreviousSearch:{},aoPreSearchCols:[],aaSorting:null,aaSortingFixed:[],asStripeClasses:null,asDestroyStripes:[],sDestroyWidth:0,aoRowCallback:[],aoHeaderCallback:[],aoFooterCallback:[],aoDrawCallback:[],aoRowCreatedCallback:[],aoPreDrawCallback:[],aoInitComplete:[],aoStateSaveParams:[],aoStateLoadParams:[],aoStateLoaded:[], +sTableId:"",nTable:null,nTHead:null,nTFoot:null,nTBody:null,nTableWrapper:null,bDeferLoading:!1,bInitialised:!1,aoOpenRows:[],sDom:null,searchDelay:null,sPaginationType:"two_button",iStateDuration:0,aoStateSave:[],aoStateLoad:[],oSavedState:null,oLoadedState:null,sAjaxSource:null,sAjaxDataProp:null,bAjaxDataGet:!0,jqXHR:null,json:l,oAjaxData:l,fnServerData:null,aoServerParams:[],sServerMethod:null,fnFormatNumber:null,aLengthMenu:null,iDraw:0,bDrawing:!1,iDrawError:-1,_iDisplayLength:10,_iDisplayStart:0, +_iRecordsTotal:0,_iRecordsDisplay:0,bJUI:null,oClasses:{},bFiltered:!1,bSorted:!1,bSortCellsTop:null,oInit:null,aoDestroyCallback:[],fnRecordsTotal:function(){return"ssp"==A(this)?1*this._iRecordsTotal:this.aiDisplayMaster.length},fnRecordsDisplay:function(){return"ssp"==A(this)?1*this._iRecordsDisplay:this.aiDisplay.length},fnDisplayEnd:function(){var a=this._iDisplayLength,b=this._iDisplayStart,c=b+a,e=this.aiDisplay.length,d=this.oFeatures,f=d.bPaginate;return d.bServerSide?!1===f||-1===a?b+e: 
+Math.min(b+a,this._iRecordsDisplay):!f||c>e||-1===a?e:c},oInstance:null,sInstance:null,iTabIndex:0,nScrollHead:null,nScrollFoot:null,aLastSort:[],oPlugins:{}};p.ext=w={classes:{},errMode:"alert",feature:[],search:[],internal:{},legacy:{ajax:null},pager:{},renderer:{pageButton:{},header:{}},order:{},type:{detect:[],search:{},order:{}},_unique:0,fnVersionCheck:p.fnVersionCheck,iApiIndex:0,oJUIClasses:{},sVersion:p.version};g.extend(w,{afnFiltering:w.search,aTypes:w.type.detect,ofnSearch:w.type.search, +oSort:w.type.order,afnSortData:w.order,aoFeatures:w.feature,oApi:w.internal,oStdClasses:w.classes,oPagination:w.pager});g.extend(p.ext.classes,{sTable:"dataTable",sNoFooter:"no-footer",sPageButton:"paginate_button",sPageButtonActive:"current",sPageButtonDisabled:"disabled",sStripeOdd:"odd",sStripeEven:"even",sRowEmpty:"dataTables_empty",sWrapper:"dataTables_wrapper",sFilter:"dataTables_filter",sInfo:"dataTables_info",sPaging:"dataTables_paginate paging_",sLength:"dataTables_length",sProcessing:"dataTables_processing", +sSortAsc:"sorting_asc",sSortDesc:"sorting_desc",sSortable:"sorting",sSortableAsc:"sorting_asc_disabled",sSortableDesc:"sorting_desc_disabled",sSortableNone:"sorting_disabled",sSortColumn:"sorting_",sFilterInput:"",sLengthSelect:"",sScrollWrapper:"dataTables_scroll",sScrollHead:"dataTables_scrollHead",sScrollHeadInner:"dataTables_scrollHeadInner",sScrollBody:"dataTables_scrollBody",sScrollFoot:"dataTables_scrollFoot",sScrollFootInner:"dataTables_scrollFootInner",sHeaderTH:"",sFooterTH:"",sSortJUIAsc:"", +sSortJUIDesc:"",sSortJUI:"",sSortJUIAscAllowed:"",sSortJUIDescAllowed:"",sSortJUIWrapper:"",sSortIcon:"",sJUIHeader:"",sJUIFooter:""});var Ca="",Ca="",E=Ca+"ui-state-default",ia=Ca+"css_right ui-icon ui-icon-",Wb=Ca+"fg-toolbar ui-toolbar ui-widget-header ui-helper-clearfix";g.extend(p.ext.oJUIClasses,p.ext.classes,{sPageButton:"fg-button ui-button "+E,sPageButtonActive:"ui-state-disabled",sPageButtonDisabled:"ui-state-disabled",sPaging:"dataTables_paginate fg-buttonset ui-buttonset fg-buttonset-multi ui-buttonset-multi paging_", +sSortAsc:E+" sorting_asc",sSortDesc:E+" sorting_desc",sSortable:E+" sorting",sSortableAsc:E+" sorting_asc_disabled",sSortableDesc:E+" sorting_desc_disabled",sSortableNone:E+" sorting_disabled",sSortJUIAsc:ia+"triangle-1-n",sSortJUIDesc:ia+"triangle-1-s",sSortJUI:ia+"carat-2-n-s",sSortJUIAscAllowed:ia+"carat-1-n",sSortJUIDescAllowed:ia+"carat-1-s",sSortJUIWrapper:"DataTables_sort_wrapper",sSortIcon:"DataTables_sort_icon",sScrollHead:"dataTables_scrollHead "+E,sScrollFoot:"dataTables_scrollFoot "+E, +sHeaderTH:E,sFooterTH:E,sJUIHeader:Wb+" ui-corner-tl ui-corner-tr",sJUIFooter:Wb+" ui-corner-bl ui-corner-br"});var Lb=p.ext.pager;g.extend(Lb,{simple:function(){return["previous","next"]},full:function(){return["first","previous","next","last"]},simple_numbers:function(a,b){return["previous",Va(a,b),"next"]},full_numbers:function(a,b){return["first","previous",Va(a,b),"next","last"]},_numbers:Va,numbers_length:7});g.extend(!0,p.ext.renderer,{pageButton:{_:function(a,b,c,e,d,f){var h=a.oClasses,i= +a.oLanguage.oPaginate,j,l,m=0,o=function(b,e){var k,p,r,q,s=function(b){Sa(a,b.data.action,true)};k=0;for(p=e.length;k").appendTo(b);o(r,q)}else{l=j="";switch(q){case "ellipsis":b.append("");break;case "first":j=i.sFirst;l=q+(d>0?"":" "+h.sPageButtonDisabled);break;case "previous":j=i.sPrevious;l=q+(d>0?"":" "+h.sPageButtonDisabled);break;case "next":j=i.sNext;l=q+(d",{"class":h.sPageButton+" 
"+l,"aria-controls":a.sTableId,"data-dt-idx":m,tabindex:a.iTabIndex,id:c===0&&typeof q==="string"?a.sTableId+"_"+q:null}).html(j).appendTo(b);Ua(r,{action:q},s);m++}}}};try{var k=g(P.activeElement).data("dt-idx");o(g(b).empty(),e);k!==null&&g(b).find("[data-dt-idx="+k+"]").focus()}catch(p){}}}});g.extend(p.ext.type.detect,[function(a,b){var c=b.oLanguage.sDecimal; +return Ya(a,c)?"num"+c:null},function(a){if(a&&!(a instanceof Date)&&(!$b.test(a)||!ac.test(a)))return null;var b=Date.parse(a);return null!==b&&!isNaN(b)||H(a)?"date":null},function(a,b){var c=b.oLanguage.sDecimal;return Ya(a,c,!0)?"num-fmt"+c:null},function(a,b){var c=b.oLanguage.sDecimal;return Qb(a,c)?"html-num"+c:null},function(a,b){var c=b.oLanguage.sDecimal;return Qb(a,c,!0)?"html-num-fmt"+c:null},function(a){return H(a)||"string"===typeof a&&-1!==a.indexOf("<")?"html":null}]);g.extend(p.ext.type.search, +{html:function(a){return H(a)?a:"string"===typeof a?a.replace(Nb," ").replace(Aa,""):""},string:function(a){return H(a)?a:"string"===typeof a?a.replace(Nb," "):a}});var za=function(a,b,c,e){if(0!==a&&(!a||"-"===a))return-Infinity;b&&(a=Pb(a,b));a.replace&&(c&&(a=a.replace(c,"")),e&&(a=a.replace(e,"")));return 1*a};g.extend(w.type.order,{"date-pre":function(a){return Date.parse(a)||0},"html-pre":function(a){return H(a)?"":a.replace?a.replace(/<.*?>/g,"").toLowerCase():a+""},"string-pre":function(a){return H(a)? +"":"string"===typeof a?a.toLowerCase():!a.toString?"":a.toString()},"string-asc":function(a,b){return ab?1:0},"string-desc":function(a,b){return ab?-1:0}});cb("");g.extend(!0,p.ext.renderer,{header:{_:function(a,b,c,e){g(a.nTable).on("order.dt.DT",function(d,f,h,g){if(a===f){d=c.idx;b.removeClass(c.sSortingClass+" "+e.sSortAsc+" "+e.sSortDesc).addClass(g[d]=="asc"?e.sSortAsc:g[d]=="desc"?e.sSortDesc:c.sSortingClass)}})},jqueryui:function(a,b,c,e){g("
      ").addClass(e.sSortJUIWrapper).append(b.contents()).append(g("").addClass(e.sSortIcon+ +" "+c.sSortingClassJUI)).appendTo(b);g(a.nTable).on("order.dt.DT",function(d,f,g,i){if(a===f){d=c.idx;b.removeClass(e.sSortAsc+" "+e.sSortDesc).addClass(i[d]=="asc"?e.sSortAsc:i[d]=="desc"?e.sSortDesc:c.sSortingClass);b.find("span."+e.sSortIcon).removeClass(e.sSortJUIAsc+" "+e.sSortJUIDesc+" "+e.sSortJUI+" "+e.sSortJUIAscAllowed+" "+e.sSortJUIDescAllowed).addClass(i[d]=="asc"?e.sSortJUIAsc:i[d]=="desc"?e.sSortJUIDesc:c.sSortingClassJUI)}})}}});p.render={number:function(a,b,c,e){return{display:function(d){var f= +0>d?"-":"",d=Math.abs(parseFloat(d)),g=parseInt(d,10),d=c?b+(d-g).toFixed(c).substring(2):"";return f+(e||"")+g.toString().replace(/\B(?=(\d{3})+(?!\d))/g,a)+d}}}};g.extend(p.ext.internal,{_fnExternApiFunc:Mb,_fnBuildAjax:qa,_fnAjaxUpdate:jb,_fnAjaxParameters:sb,_fnAjaxUpdateDraw:tb,_fnAjaxDataSrc:ra,_fnAddColumn:Ea,_fnColumnOptions:ja,_fnAdjustColumnSizing:X,_fnVisibleToColumnIndex:ka,_fnColumnIndexToVisible:$,_fnVisbleColumns:aa,_fnGetColumns:Z,_fnColumnTypes:Ga,_fnApplyColumnDefs:hb,_fnHungarianMap:V, +_fnCamelToHungarian:G,_fnLanguageCompat:O,_fnBrowserDetect:fb,_fnAddData:I,_fnAddTr:la,_fnNodeToDataIndex:function(a,b){return b._DT_RowIndex!==l?b._DT_RowIndex:null},_fnNodeToColumnIndex:function(a,b,c){return g.inArray(c,a.aoData[b].anCells)},_fnGetCellData:v,_fnSetCellData:Ha,_fnSplitObjNotation:Ja,_fnGetObjectDataFn:W,_fnSetObjectDataFn:Q,_fnGetDataMaster:Ka,_fnClearTable:na,_fnDeleteIndex:oa,_fnInvalidate:ca,_fnGetRowElements:ma,_fnCreateTr:Ia,_fnBuildHead:ib,_fnDrawHead:ea,_fnDraw:L,_fnReDraw:M, +_fnAddOptionsHtml:lb,_fnDetectHeader:da,_fnGetUniqueThs:pa,_fnFeatureHtmlFilter:nb,_fnFilterComplete:fa,_fnFilterCustom:wb,_fnFilterColumn:vb,_fnFilter:ub,_fnFilterCreateSearch:Pa,_fnEscapeRegex:ua,_fnFilterData:xb,_fnFeatureHtmlInfo:qb,_fnUpdateInfo:Ab,_fnInfoMacros:Bb,_fnInitialise:ga,_fnInitComplete:sa,_fnLengthChange:Qa,_fnFeatureHtmlLength:mb,_fnFeatureHtmlPaginate:rb,_fnPageChange:Sa,_fnFeatureHtmlProcessing:ob,_fnProcessingDisplay:B,_fnFeatureHtmlTable:pb,_fnScrollDraw:Y,_fnApplyToChildren:F, +_fnCalculateColumnWidths:Fa,_fnThrottle:ta,_fnConvertToWidth:Cb,_fnScrollingWidthAdjust:Eb,_fnGetWidestNode:Db,_fnGetMaxLenString:Fb,_fnStringToCss:s,_fnScrollBarWidth:Gb,_fnSortFlatten:T,_fnSort:kb,_fnSortAria:Ib,_fnSortListener:Ta,_fnSortAttachListener:Na,_fnSortingClasses:wa,_fnSortData:Hb,_fnSaveState:xa,_fnLoadState:Jb,_fnSettingsFromNode:ya,_fnLog:R,_fnMap:D,_fnBindAction:Ua,_fnCallbackReg:x,_fnCallbackFire:u,_fnLengthOverflow:Ra,_fnRenderer:Oa,_fnDataSource:A,_fnRowAttributes:La,_fnCalculateEnd:function(){}}); +g.fn.dataTable=p;g.fn.dataTableSettings=p.settings;g.fn.dataTableExt=p.ext;g.fn.DataTable=function(a){return g(this).dataTable(a).api()};g.each(p,function(a,b){g.fn.DataTable[a]=b});return g.fn.dataTable};"function"===typeof define&&define.amd?define("datatables",["jquery"],O):"object"===typeof exports?O(require("jquery")):jQuery&&!jQuery.fn.dataTable&&O(jQuery)})(window,document); diff --git a/core/src/main/resources/org/apache/spark/ui/static/jquery.mustache.js b/core/src/main/resources/org/apache/spark/ui/static/jquery.mustache.js new file mode 100644 index 000000000000..14925bf93d0f --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/jquery.mustache.js @@ -0,0 +1,592 @@ +/* +Shameless port of a shameless port +@defunkt => @janl => @aq + +See http://github.com/defunkt/mustache for more info. +*/ + +;(function($) { + +/*! 
+ * mustache.js - Logic-less {{mustache}} templates with JavaScript + * http://github.com/janl/mustache.js + */ + +/*global define: false*/ + +(function (root, factory) { + if (typeof exports === "object" && exports) { + factory(exports); // CommonJS + } else { + var mustache = {}; + factory(mustache); + if (typeof define === "function" && define.amd) { + define(mustache); // AMD + } else { + root.Mustache = mustache; // ++ + ++ + } else if (requestedIncomplete) {

      No incomplete applications found!

      } else {

      No completed applications found!

      ++ -

      Did you specify the correct logging directory? - Please verify your setting of - spark.history.fs.logDirectory and whether you have the permissions to - access it.
      It is also possible that your application did not run to - completion or did not stop the SparkContext. -

      +

      Did you specify the correct logging directory? + Please verify your setting of + spark.history.fs.logDirectory and whether you have the permissions to + access it.
      It is also possible that your application did not run to + completion or did not stop the SparkContext. +

      } - } - - { + } + + + { if (requestedIncomplete) { "Back to completed applications" } else { "Show incomplete applications" } - } - -
      + } +
      +
      UIUtils.basicSparkPage(content, "History Server") } - private val appHeader = Seq( - "App ID", - "App Name", - "Started", - "Completed", - "Duration", - "Spark User", - "Last Updated") - - private val appWithAttemptHeader = Seq( - "App ID", - "App Name", - "Attempt ID", - "Started", - "Completed", - "Duration", - "Spark User", - "Last Updated") - - private def rangeIndices( - range: Seq[Int], - condition: Int => Boolean, - showIncomplete: Boolean): Seq[Node] = { - range.filter(condition).map(nextPage => - {nextPage} ) - } - - private def attemptRow( - renderAttemptIdColumn: Boolean, - info: ApplicationHistoryInfo, - attempt: ApplicationAttemptInfo, - isFirst: Boolean): Seq[Node] = { - val uiAddress = UIUtils.prependBaseUri(HistoryServer.getAttemptURI(info.id, attempt.attemptId)) - val startTime = UIUtils.formatDate(attempt.startTime) - val endTime = if (attempt.endTime > 0) UIUtils.formatDate(attempt.endTime) else "-" - val duration = - if (attempt.endTime > 0) { - UIUtils.formatDuration(attempt.endTime - attempt.startTime) - } else { - "-" - } - val lastUpdated = UIUtils.formatDate(attempt.lastUpdated) - - { - if (isFirst) { - if (info.attempts.size > 1 || renderAttemptIdColumn) { - - {info.id} - - {info.name} - } else { - {info.id} - {info.name} - } - } else { - Nil - } - } - { - if (renderAttemptIdColumn) { - if (info.attempts.size > 1 && attempt.attemptId.isDefined) { - {attempt.attemptId.get} - } else { -   - } - } else { - Nil - } - } - {startTime} - {endTime} - - {duration} - {attempt.sparkUser} - {lastUpdated} - - } - - private def appRow(info: ApplicationHistoryInfo): Seq[Node] = { - attemptRow(false, info, info.attempts.head, true) - } - - private def appWithAttemptRow(info: ApplicationHistoryInfo): Seq[Node] = { - attemptRow(true, info, info.attempts.head, true) ++ - info.attempts.drop(1).flatMap(attemptRow(true, info, _, false)) - } - - private def makePageLink(linkPage: Int, showIncomplete: Boolean): String = { - UIUtils.prependBaseUri("/?" + Array( - "page=" + linkPage, - "showIncomplete=" + showIncomplete - ).mkString("&")) + private def makePageLink(showIncomplete: Boolean): String = { + UIUtils.prependBaseUri("/?" 
+ "showIncomplete=" + showIncomplete) } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala index 0fc0fb59d861..0f3018368246 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala @@ -71,6 +71,13 @@ private[spark] object ApplicationsListResource { attemptId = internalAttemptInfo.attemptId, startTime = new Date(internalAttemptInfo.startTime), endTime = new Date(internalAttemptInfo.endTime), + duration = + if (internalAttemptInfo.endTime > 0) { + internalAttemptInfo.endTime - internalAttemptInfo.startTime + } else { + 0 + }, + lastUpdated = new Date(internalAttemptInfo.lastUpdated), sparkUser = internalAttemptInfo.sparkUser, completed = internalAttemptInfo.completed ) @@ -93,6 +100,13 @@ private[spark] object ApplicationsListResource { attemptId = None, startTime = new Date(internal.startTime), endTime = new Date(internal.endTime), + duration = + if (internal.endTime > 0) { + internal.endTime - internal.startTime + } else { + 0 + }, + lastUpdated = new Date(internal.endTime), sparkUser = internal.desc.user, completed = completed )) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 3adf5b1109af..2b0079f5fd62 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -35,6 +35,8 @@ class ApplicationAttemptInfo private[spark]( val attemptId: Option[String], val startTime: Date, val endTime: Date, + val lastUpdated: Date, + val duration: Long, val sparkUser: String, val completed: Boolean = false) diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index cf45414c4f78..6cc30eeaf5d8 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -114,6 +114,8 @@ private[spark] class SparkUI private ( attemptId = None, startTime = new Date(startTime), endTime = new Date(-1), + duration = 0, + lastUpdated = new Date(startTime), sparkUser = "", completed = false )) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 1949c4b3cbf4..4ebee9093d41 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -157,11 +157,22 @@ private[spark] object UIUtils extends Logging { def commonHeaderNodes: Seq[Node] = { + + + + + + + + + diff --git a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json index d575bf2f284b..5bbb4ceb9722 100644 --- a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json @@ -4,6 +4,8 @@ "attempts" : [ { "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:11.398GMT", + "lastUpdated" : "", + "duration" : 10505, "sparkUser" : "irashid", "completed" : true } ] @@ -14,12 +16,16 @@ "attemptId" : "2", "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:00.950GMT", + "lastUpdated" : "", + "duration" 
: 57, "sparkUser" : "irashid", "completed" : true }, { "attemptId" : "1", "startTime" : "2015-05-06T13:03:00.880GMT", "endTime" : "2015-05-06T13:03:00.890GMT", + "lastUpdated" : "", + "duration" : 10, "sparkUser" : "irashid", "completed" : true } ] @@ -30,12 +36,16 @@ "attemptId" : "2", "startTime" : "2015-03-17T23:11:50.242GMT", "endTime" : "2015-03-17T23:12:25.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true }, { "attemptId" : "1", "startTime" : "2015-03-16T19:25:10.242GMT", "endTime" : "2015-03-16T19:25:45.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true } ] @@ -45,6 +55,8 @@ "attempts" : [ { "startTime" : "2015-02-28T00:02:38.277GMT", "endTime" : "2015-02-28T00:02:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, "sparkUser" : "irashid", "completed" : true } ] @@ -54,6 +66,8 @@ "attempts" : [ { "startTime" : "2015-02-03T16:42:59.720GMT", "endTime" : "2015-02-03T16:43:08.731GMT", + "lastUpdated" : "", + "duration" : 9011, "sparkUser" : "irashid", "completed" : true } ] @@ -63,7 +77,9 @@ "attempts" : [ { "startTime" : "2015-02-03T16:42:38.277GMT", "endTime" : "2015-02-03T16:42:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, "sparkUser" : "irashid", "completed" : true } ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json index d575bf2f284b..5bbb4ceb9722 100644 --- a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json @@ -4,6 +4,8 @@ "attempts" : [ { "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:11.398GMT", + "lastUpdated" : "", + "duration" : 10505, "sparkUser" : "irashid", "completed" : true } ] @@ -14,12 +16,16 @@ "attemptId" : "2", "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:00.950GMT", + "lastUpdated" : "", + "duration" : 57, "sparkUser" : "irashid", "completed" : true }, { "attemptId" : "1", "startTime" : "2015-05-06T13:03:00.880GMT", "endTime" : "2015-05-06T13:03:00.890GMT", + "lastUpdated" : "", + "duration" : 10, "sparkUser" : "irashid", "completed" : true } ] @@ -30,12 +36,16 @@ "attemptId" : "2", "startTime" : "2015-03-17T23:11:50.242GMT", "endTime" : "2015-03-17T23:12:25.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true }, { "attemptId" : "1", "startTime" : "2015-03-16T19:25:10.242GMT", "endTime" : "2015-03-16T19:25:45.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true } ] @@ -45,6 +55,8 @@ "attempts" : [ { "startTime" : "2015-02-28T00:02:38.277GMT", "endTime" : "2015-02-28T00:02:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, "sparkUser" : "irashid", "completed" : true } ] @@ -54,6 +66,8 @@ "attempts" : [ { "startTime" : "2015-02-03T16:42:59.720GMT", "endTime" : "2015-02-03T16:43:08.731GMT", + "lastUpdated" : "", + "duration" : 9011, "sparkUser" : "irashid", "completed" : true } ] @@ -63,7 +77,9 @@ "attempts" : [ { "startTime" : "2015-02-03T16:42:38.277GMT", "endTime" : "2015-02-03T16:42:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, "sparkUser" : "irashid", "completed" : true } ] -} ] \ No newline at end of file +} ] diff --git 
a/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json index 483632a3956e..3f80a529a08b 100644 --- a/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json @@ -4,7 +4,9 @@ "attempts" : [ { "startTime" : "2015-02-03T16:42:38.277GMT", "endTime" : "2015-02-03T16:42:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, "sparkUser" : "irashid", "completed" : true } ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json index 4b85690fd919..508bdc17efe9 100644 --- a/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json @@ -4,6 +4,8 @@ "attempts" : [ { "startTime" : "2015-02-03T16:42:59.720GMT", "endTime" : "2015-02-03T16:43:08.731GMT", + "lastUpdated" : "", + "duration" : 9011, "sparkUser" : "irashid", "completed" : true } ] @@ -13,7 +15,9 @@ "attempts" : [ { "startTime" : "2015-02-03T16:42:38.277GMT", "endTime" : "2015-02-03T16:42:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, "sparkUser" : "irashid", "completed" : true } ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json index 15c2de8ef99e..5dca7d73de0c 100644 --- a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json @@ -4,6 +4,8 @@ "attempts" : [ { "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:11.398GMT", + "lastUpdated" : "", + "duration" : 10505, "sparkUser" : "irashid", "completed" : true } ] @@ -14,12 +16,16 @@ "attemptId" : "2", "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:00.950GMT", + "lastUpdated" : "", + "duration" : 57, "sparkUser" : "irashid", "completed" : true }, { "attemptId" : "1", "startTime" : "2015-05-06T13:03:00.880GMT", "endTime" : "2015-05-06T13:03:00.890GMT", + "lastUpdated" : "", + "duration" : 10, "sparkUser" : "irashid", "completed" : true } ] @@ -30,12 +36,16 @@ "attemptId" : "2", "startTime" : "2015-03-17T23:11:50.242GMT", "endTime" : "2015-03-17T23:12:25.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true }, { "attemptId" : "1", "startTime" : "2015-03-16T19:25:10.242GMT", "endTime" : "2015-03-16T19:25:45.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true } ] @@ -46,8 +56,10 @@ { "startTime": "2015-02-28T00:02:38.277GMT", "endTime": "2015-02-28T00:02:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, "sparkUser": "irashid", "completed": true } ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json index 07489ad96414..cca32c791074 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json +++ 
b/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json @@ -4,7 +4,9 @@ "attempts" : [ { "startTime" : "2015-02-03T16:42:59.720GMT", "endTime" : "2015-02-03T16:43:08.731GMT", + "lastUpdated" : "", + "duration" : 9011, "sparkUser" : "irashid", "completed" : true } ] -} \ No newline at end of file +} diff --git a/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json index 8f3d7160c723..1ea1779e8369 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json @@ -5,13 +5,17 @@ "attemptId" : "2", "startTime" : "2015-03-17T23:11:50.242GMT", "endTime" : "2015-03-17T23:12:25.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true }, { "attemptId" : "1", "startTime" : "2015-03-16T19:25:10.242GMT", "endTime" : "2015-03-16T19:25:45.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true } ] -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 18659fc0c18d..be55b2e0fe1b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -139,7 +139,24 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers code should be (HttpServletResponse.SC_OK) jsonOpt should be ('defined) errOpt should be (None) - val json = jsonOpt.get + val jsonOrg = jsonOpt.get + + // SPARK-10873 added the lastUpdated field for each application's attempt, + // the REST API returns the last modified time of EVENT LOG file for this field. + // It is not applicable to hard-code this dynamic field in a static expected file, + // so here we skip checking the lastUpdated field's value (setting it as ""). 
+ val json = if (jsonOrg.indexOf("lastUpdated") >= 0) { + val subStrings = jsonOrg.split(",") + for (i <- subStrings.indices) { + if (subStrings(i).indexOf("lastUpdated") >= 0) { + subStrings(i) = "\"lastUpdated\":\"\"" + } + } + subStrings.mkString(",") + } else { + jsonOrg + } + val exp = IOUtils.toString(new FileInputStream( new File(expRoot, HistoryServerSuite.sanitizePath(name) + "_expectation.json"))) // compare the ASTs so formatting differences don't cause failures @@ -241,30 +258,6 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers getContentAndCode("foobar")._1 should be (HttpServletResponse.SC_NOT_FOUND) } - test("generate history page with relative links") { - val historyServer = mock[HistoryServer] - val request = mock[HttpServletRequest] - val ui = mock[SparkUI] - val link = "/history/app1" - val info = new ApplicationHistoryInfo("app1", "app1", - List(ApplicationAttemptInfo(None, 0, 2, 1, "xxx", true))) - when(historyServer.getApplicationList()).thenReturn(Seq(info)) - when(ui.basePath).thenReturn(link) - when(historyServer.getProviderConfig()).thenReturn(Map[String, String]()) - val page = new HistoryPage(historyServer) - - // when - val response = page.render(request) - - // then - val links = response \\ "a" - val justHrefs = for { - l <- links - attrs <- l.attribute("href") - } yield (attrs.toString) - justHrefs should contain (UIUtils.prependBaseUri(resource = link)) - } - test("relative links are prefixed with uiRoot (spark.ui.proxyBase)") { val proxyBaseBeforeTest = System.getProperty("spark.ui.proxyBase") val uiRoot = Option(System.getenv("APPLICATION_WEB_PROXY_BASE")).getOrElse("/testwebproxybase") diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 968a2903f301..a3ae4d2b730f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -45,6 +45,10 @@ object MimaExcludes { excludePackage("org.apache.spark.sql.execution"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.feature.PCAModel.this"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.StageData.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.status.api.v1.ApplicationAttemptInfo.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.status.api.v1.ApplicationAttemptInfo.$default$5"), // SPARK-12600 Remove SQL deprecated methods ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$QueryExecution"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$SparkPlanner"), From c5f745ede01831b59c57effa7de88c648b82c13d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 29 Jan 2016 10:24:23 -0800 Subject: [PATCH 062/131] [SPARK-13072] [SQL] simplify and improve murmur3 hash expression codegen simplify(remove several unnecessary local variables) the generated code of hash expression, and avoid null check if possible. 
generated code comparison for `hash(int, double, string, array)`: **before:** ``` public UnsafeRow apply(InternalRow i) { /* hash(input[0, int],input[1, double],input[2, string],input[3, array],42) */ int value1 = 42; /* input[0, int] */ int value3 = i.getInt(0); if (!false) { value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(value3, value1); } /* input[1, double] */ double value5 = i.getDouble(1); if (!false) { value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashLong(Double.doubleToLongBits(value5), value1); } /* input[2, string] */ boolean isNull6 = i.isNullAt(2); UTF8String value7 = isNull6 ? null : (i.getUTF8String(2)); if (!isNull6) { value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(value7.getBaseObject(), value7.getBaseOffset(), value7.numBytes(), value1); } /* input[3, array] */ boolean isNull8 = i.isNullAt(3); ArrayData value9 = isNull8 ? null : (i.getArray(3)); if (!isNull8) { int result10 = value1; for (int index11 = 0; index11 < value9.numElements(); index11++) { if (!value9.isNullAt(index11)) { final int element12 = value9.getInt(index11); result10 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(element12, result10); } } value1 = result10; } } ``` **after:** ``` public UnsafeRow apply(InternalRow i) { /* hash(input[0, int],input[1, double],input[2, string],input[3, array],42) */ int value1 = 42; /* input[0, int] */ int value3 = i.getInt(0); value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(value3, value1); /* input[1, double] */ double value5 = i.getDouble(1); value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashLong(Double.doubleToLongBits(value5), value1); /* input[2, string] */ boolean isNull6 = i.isNullAt(2); UTF8String value7 = isNull6 ? null : (i.getUTF8String(2)); if (!isNull6) { value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(value7.getBaseObject(), value7.getBaseOffset(), value7.numBytes(), value1); } /* input[3, array] */ boolean isNull8 = i.isNullAt(3); ArrayData value9 = isNull8 ? null : (i.getArray(3)); if (!isNull8) { for (int index10 = 0; index10 < value9.numElements(); index10++) { final int element11 = value9.getInt(index10); value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(element11, value1); } } rowWriter14.write(0, value1); return result12; } ``` Author: Wenchen Fan Closes #10974 from cloud-fan/codegen. 
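As a rough illustration of the pattern behind `generateNullCheck` in the diff below, here is a small self-contained Scala sketch (simplified, with hypothetical names; it is not the actual Catalyst `CodegenContext` API): the hash update is wrapped in a null guard only when the input expression is nullable.

```
// Simplified sketch (hypothetical names, not the real codegen API): emit a null guard
// around a generated hash update only when the input expression is nullable.
object NullCheckCodegenSketch {
  def nullSafeUpdate(nullable: Boolean, isNull: String)(update: String): String =
    if (nullable) {
      s"""if (!$isNull) {
         |  $update
         |}""".stripMargin
    } else {
      update
    }

  def main(args: Array[String]): Unit = {
    // Non-nullable int column: the update is printed without any guard.
    println(nullSafeUpdate(nullable = false, isNull = "isNull3")(
      "value1 = Murmur3_x86_32.hashInt(value3, value1);"))

    // Nullable string column: the update is wrapped in `if (!isNull6) { ... }`.
    println(nullSafeUpdate(nullable = true, isNull = "isNull6")(
      "value1 = Murmur3_x86_32.hashUnsafeBytes(value7.getBaseObject(), " +
        "value7.getBaseOffset(), value7.numBytes(), value1);"))
  }
}
```

Skipping the guard for non-nullable inputs is what removes the dead `if (!false)` branches and the extra locals visible in the before/after comparison above.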
--- .../spark/sql/catalyst/expressions/misc.scala | 155 ++++++++---------- 1 file changed, 69 insertions(+), 86 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 493e0aae01af..8480c3f9a12f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -325,36 +325,62 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression override def genCode(ctx: CodegenContext, ev: ExprCode): String = { ev.isNull = "false" - val childrenHash = children.zipWithIndex.map { - case (child, dt) => - val childGen = child.gen(ctx) - val childHash = computeHash(childGen.value, child.dataType, ev.value, ctx) - s""" - ${childGen.code} - if (!${childGen.isNull}) { - ${childHash.code} - ${ev.value} = ${childHash.value}; - } - """ + val childrenHash = children.map { child => + val childGen = child.gen(ctx) + childGen.code + generateNullCheck(child.nullable, childGen.isNull) { + computeHash(childGen.value, child.dataType, ev.value, ctx) + } }.mkString("\n") + s""" int ${ev.value} = $seed; $childrenHash """ } + private def generateNullCheck(nullable: Boolean, isNull: String)(execution: String): String = { + if (nullable) { + s""" + if (!$isNull) { + $execution + } + """ + } else { + "\n" + execution + } + } + + private def nullSafeElementHash( + input: String, + index: String, + nullable: Boolean, + elementType: DataType, + result: String, + ctx: CodegenContext): String = { + val element = ctx.freshName("element") + + generateNullCheck(nullable, s"$input.isNullAt($index)") { + s""" + final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)}; + ${computeHash(element, elementType, result, ctx)} + """ + } + } + private def computeHash( input: String, dataType: DataType, - seed: String, - ctx: CodegenContext): ExprCode = { + result: String, + ctx: CodegenContext): String = { val hasher = classOf[Murmur3_x86_32].getName - def hashInt(i: String): ExprCode = inlineValue(s"$hasher.hashInt($i, $seed)") - def hashLong(l: String): ExprCode = inlineValue(s"$hasher.hashLong($l, $seed)") - def inlineValue(v: String): ExprCode = ExprCode(code = "", isNull = "false", value = v) + + def hashInt(i: String): String = s"$result = $hasher.hashInt($i, $result);" + def hashLong(l: String): String = s"$result = $hasher.hashLong($l, $result);" + def hashBytes(b: String): String = + s"$result = $hasher.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length, $result);" dataType match { - case NullType => inlineValue(seed) + case NullType => "" case BooleanType => hashInt(s"$input ? 
1 : 0") case ByteType | ShortType | IntegerType | DateType => hashInt(input) case LongType | TimestampType => hashLong(input) @@ -365,91 +391,48 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression hashLong(s"$input.toUnscaledLong()") } else { val bytes = ctx.freshName("bytes") - val code = s"byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();" - val offset = "Platform.BYTE_ARRAY_OFFSET" - val result = s"$hasher.hashUnsafeBytes($bytes, $offset, $bytes.length, $seed)" - ExprCode(code, "false", result) + s""" + final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray(); + ${hashBytes(bytes)} + """ } case CalendarIntervalType => - val microsecondsHash = s"$hasher.hashLong($input.microseconds, $seed)" - val monthsHash = s"$hasher.hashInt($input.months, $microsecondsHash)" - inlineValue(monthsHash) - case BinaryType => - val offset = "Platform.BYTE_ARRAY_OFFSET" - inlineValue(s"$hasher.hashUnsafeBytes($input, $offset, $input.length, $seed)") + val microsecondsHash = s"$hasher.hashLong($input.microseconds, $result)" + s"$result = $hasher.hashInt($input.months, $microsecondsHash);" + case BinaryType => hashBytes(input) case StringType => val baseObject = s"$input.getBaseObject()" val baseOffset = s"$input.getBaseOffset()" val numBytes = s"$input.numBytes()" - inlineValue(s"$hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $seed)") + s"$result = $hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);" - case ArrayType(et, _) => - val result = ctx.freshName("result") + case ArrayType(et, containsNull) => val index = ctx.freshName("index") - val element = ctx.freshName("element") - val elementHash = computeHash(element, et, result, ctx) - val code = - s""" - int $result = $seed; - for (int $index = 0; $index < $input.numElements(); $index++) { - if (!$input.isNullAt($index)) { - final ${ctx.javaType(et)} $element = ${ctx.getValue(input, et, index)}; - ${elementHash.code} - $result = ${elementHash.value}; - } - } - """ - ExprCode(code, "false", result) + s""" + for (int $index = 0; $index < $input.numElements(); $index++) { + ${nullSafeElementHash(input, index, containsNull, et, result, ctx)} + } + """ - case MapType(kt, vt, _) => - val result = ctx.freshName("result") + case MapType(kt, vt, valueContainsNull) => val index = ctx.freshName("index") val keys = ctx.freshName("keys") val values = ctx.freshName("values") - val key = ctx.freshName("key") - val value = ctx.freshName("value") - val keyHash = computeHash(key, kt, result, ctx) - val valueHash = computeHash(value, vt, result, ctx) - val code = - s""" - int $result = $seed; - final ArrayData $keys = $input.keyArray(); - final ArrayData $values = $input.valueArray(); - for (int $index = 0; $index < $input.numElements(); $index++) { - final ${ctx.javaType(kt)} $key = ${ctx.getValue(keys, kt, index)}; - ${keyHash.code} - $result = ${keyHash.value}; - if (!$values.isNullAt($index)) { - final ${ctx.javaType(vt)} $value = ${ctx.getValue(values, vt, index)}; - ${valueHash.code} - $result = ${valueHash.value}; - } - } - """ - ExprCode(code, "false", result) + s""" + final ArrayData $keys = $input.keyArray(); + final ArrayData $values = $input.valueArray(); + for (int $index = 0; $index < $input.numElements(); $index++) { + ${nullSafeElementHash(keys, index, false, kt, result, ctx)} + ${nullSafeElementHash(values, index, valueContainsNull, vt, result, ctx)} + } + """ case StructType(fields) => - val result = ctx.freshName("result") - val fieldsHash = 
fields.map(_.dataType).zipWithIndex.map { - case (dt, index) => - val field = ctx.freshName("field") - val fieldHash = computeHash(field, dt, result, ctx) - s""" - if (!$input.isNullAt($index)) { - final ${ctx.javaType(dt)} $field = ${ctx.getValue(input, dt, index.toString)}; - ${fieldHash.code} - $result = ${fieldHash.value}; - } - """ + fields.zipWithIndex.map { case (field, index) => + nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) }.mkString("\n") - val code = - s""" - int $result = $seed; - $fieldsHash - """ - ExprCode(code, "false", result) - case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, seed, ctx) + case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, result, ctx) } } } From 5f686cc8b74ea9e36f56c31f14df90d134fd9343 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 29 Jan 2016 11:22:12 -0800 Subject: [PATCH 063/131] [SPARK-12656] [SQL] Implement Intersect with Left-semi Join Our current Intersect physical operator simply delegates to RDD.intersect. We should remove the Intersect physical operator and simply transform a logical intersect into a semi-join with distinct. This way, we can take advantage of all the benefits of join implementations (e.g. managed memory, code generation, broadcast joins). After a search, I found one of the mainstream RDBMS did the same. In their query explain, Intersect is replaced by Left-semi Join. Left-semi Join could help outer-join elimination in Optimizer, as shown in the PR: https://github.com/apache/spark/pull/10566 Author: gatorsmile Author: xiaoli Author: Xiao Li Closes #10630 from gatorsmile/IntersectBySemiJoin. --- .../sql/catalyst/analysis/Analyzer.scala | 113 ++++++++++-------- .../sql/catalyst/analysis/CheckAnalysis.scala | 14 ++- .../sql/catalyst/optimizer/Optimizer.scala | 45 ++++--- .../plans/logical/basicOperators.scala | 32 +++-- .../sql/catalyst/analysis/AnalysisSuite.scala | 5 + .../optimizer/AggregateOptimizeSuite.scala | 12 -- .../optimizer/ReplaceOperatorSuite.scala | 59 +++++++++ .../optimizer/SetOperationSuite.scala | 15 +-- .../spark/sql/execution/SparkStrategies.scala | 5 +- .../spark/sql/execution/basicOperators.scala | 12 -- .../org/apache/spark/sql/DataFrameSuite.scala | 21 ++++ 11 files changed, 211 insertions(+), 122 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 33d76eeb2128..5fe700ee0067 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -344,6 +344,63 @@ class Analyzer( } } + /** + * Generate a new logical plan for the right child with different expression IDs + * for all conflicting attributes. + */ + private def dedupRight (left: LogicalPlan, right: LogicalPlan): LogicalPlan = { + val conflictingAttributes = left.outputSet.intersect(right.outputSet) + logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " + + s"between $left and $right") + + right.collect { + // Handle base relations that might appear more than once. 
+ case oldVersion: MultiInstanceRelation + if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => + val newVersion = oldVersion.newInstance() + (oldVersion, newVersion) + + // Handle projects that create conflicting aliases. + case oldVersion @ Project(projectList, _) + if findAliases(projectList).intersect(conflictingAttributes).nonEmpty => + (oldVersion, oldVersion.copy(projectList = newAliases(projectList))) + + case oldVersion @ Aggregate(_, aggregateExpressions, _) + if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => + (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))) + + case oldVersion: Generate + if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty => + val newOutput = oldVersion.generatorOutput.map(_.newInstance()) + (oldVersion, oldVersion.copy(generatorOutput = newOutput)) + + case oldVersion @ Window(_, windowExpressions, _, _, child) + if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes) + .nonEmpty => + (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))) + } + // Only handle first case, others will be fixed on the next pass. + .headOption match { + case None => + /* + * No result implies that there is a logical plan node that produces new references + * that this rule cannot handle. When that is the case, there must be another rule + * that resolves these conflicts. Otherwise, the analysis will fail. + */ + right + case Some((oldRelation, newRelation)) => + val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) + val newRight = right transformUp { + case r if r == oldRelation => newRelation + } transformUp { + case other => other transformExpressions { + case a: Attribute => attributeRewrites.get(a).getOrElse(a) + } + } + newRight + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p: LogicalPlan if !p.childrenResolved => p @@ -388,57 +445,11 @@ class Analyzer( .map(_.asInstanceOf[NamedExpression]) a.copy(aggregateExpressions = expanded) - // Special handling for cases when self-join introduce duplicate expression ids. - case j @ Join(left, right, _, _) if !j.selfJoinResolved => - val conflictingAttributes = left.outputSet.intersect(right.outputSet) - logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} in $j") - - right.collect { - // Handle base relations that might appear more than once. - case oldVersion: MultiInstanceRelation - if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => - val newVersion = oldVersion.newInstance() - (oldVersion, newVersion) - - // Handle projects that create conflicting aliases. 
- case oldVersion @ Project(projectList, _) - if findAliases(projectList).intersect(conflictingAttributes).nonEmpty => - (oldVersion, oldVersion.copy(projectList = newAliases(projectList))) - - case oldVersion @ Aggregate(_, aggregateExpressions, _) - if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => - (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))) - - case oldVersion: Generate - if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty => - val newOutput = oldVersion.generatorOutput.map(_.newInstance()) - (oldVersion, oldVersion.copy(generatorOutput = newOutput)) - - case oldVersion @ Window(_, windowExpressions, _, _, child) - if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes) - .nonEmpty => - (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))) - } - // Only handle first case, others will be fixed on the next pass. - .headOption match { - case None => - /* - * No result implies that there is a logical plan node that produces new references - * that this rule cannot handle. When that is the case, there must be another rule - * that resolves these conflicts. Otherwise, the analysis will fail. - */ - j - case Some((oldRelation, newRelation)) => - val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) - val newRight = right transformUp { - case r if r == oldRelation => newRelation - } transformUp { - case other => other transformExpressions { - case a: Attribute => attributeRewrites.get(a).getOrElse(a) - } - } - j.copy(right = newRight) - } + // To resolve duplicate expression IDs for Join and Intersect + case j @ Join(left, right, _, _) if !j.duplicateResolved => + j.copy(right = dedupRight(left, right)) + case i @ Intersect(left, right) if !i.duplicateResolved => + i.copy(right = dedupRight(left, right)) // When resolve `SortOrder`s in Sort based on child, don't report errors as // we still have chance to resolve it based on grandchild diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index f2e78d97442e..4a2f2b8bc6e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -214,9 +214,8 @@ trait CheckAnalysis { s"""Only a single table generating function is allowed in a SELECT clause, found: | ${exprs.map(_.prettyString).mkString(",")}""".stripMargin) - // Special handling for cases when self-join introduce duplicate expression ids. 
- case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty => - val conflictingAttributes = left.outputSet.intersect(right.outputSet) + case j: Join if !j.duplicateResolved => + val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet) failAnalysis( s""" |Failure when resolving conflicting references in Join: @@ -224,6 +223,15 @@ trait CheckAnalysis { |Conflicting attributes: ${conflictingAttributes.mkString(",")} |""".stripMargin) + case i: Intersect if !i.duplicateResolved => + val conflictingAttributes = i.left.outputSet.intersect(i.right.outputSet) + failAnalysis( + s""" + |Failure when resolving conflicting references in Intersect: + |$plan + |Conflicting attributes: ${conflictingAttributes.mkString(",")} + |""".stripMargin) + case o if !o.resolved => failAnalysis( s"unresolved operator ${operator.simpleString}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 6addc2080648..f156b5d10acc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -52,8 +52,10 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { // since the other rules might make two separate Unions operators adjacent. Batch("Union", Once, CombineUnions) :: + Batch("Replace Operators", FixedPoint(100), + ReplaceIntersectWithSemiJoin, + ReplaceDistinctWithAggregate) :: Batch("Aggregate", FixedPoint(100), - ReplaceDistinctWithAggregate, RemoveLiteralFromGroupExpressions) :: Batch("Operator Optimizations", FixedPoint(100), // Operator push down @@ -124,18 +126,13 @@ object EliminateSerialization extends Rule[LogicalPlan] { } /** - * Pushes certain operations to both sides of a Union, Intersect or Except operator. + * Pushes certain operations to both sides of a Union or Except operator. * Operations that are safe to pushdown are listed as follows. * Union: * Right now, Union means UNION ALL, which does not de-duplicate rows. So, it is * safe to pushdown Filters and Projections through it. Once we add UNION DISTINCT, * we will not be able to pushdown Projections. * - * Intersect: - * It is not safe to pushdown Projections through it because we need to get the - * intersect of rows by comparing the entire rows. It is fine to pushdown Filters - * with deterministic condition. - * * Except: * It is not safe to pushdown Projections through it because we need to get the * intersect of rows by comparing the entire rows. It is fine to pushdown Filters @@ -153,7 +150,7 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { /** * Rewrites an expression so that it can be pushed to the right side of a - * Union, Intersect or Except operator. This method relies on the fact that the output attributes + * Union or Except operator. This method relies on the fact that the output attributes * of a union/intersect/except are always equal to the left child's output. 
*/ private def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]) = { @@ -210,17 +207,6 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { } Filter(nondeterministic, Union(newFirstChild +: newOtherChildren)) - // Push down filter through INTERSECT - case Filter(condition, Intersect(left, right)) => - val (deterministic, nondeterministic) = partitionByDeterministic(condition) - val rewrites = buildRewrites(left, right) - Filter(nondeterministic, - Intersect( - Filter(deterministic, left), - Filter(pushToRight(deterministic, rewrites), right) - ) - ) - // Push down filter through EXCEPT case Filter(condition, Except(left, right)) => val (deterministic, nondeterministic) = partitionByDeterministic(condition) @@ -1054,6 +1040,27 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] { } } +/** + * Replaces logical [[Intersect]] operator with a left-semi [[Join]] operator. + * {{{ + * SELECT a1, a2 FROM Tab1 INTERSECT SELECT b1, b2 FROM Tab2 + * ==> SELECT DISTINCT a1, a2 FROM Tab1 LEFT SEMI JOIN Tab2 ON a1<=>b1 AND a2<=>b2 + * }}} + * + * Note: + * 1. This rule is only applicable to INTERSECT DISTINCT. Do not use it for INTERSECT ALL. + * 2. This rule has to be done after de-duplicating the attributes; otherwise, the generated + * join conditions will be incorrect. + */ +object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Intersect(left, right) => + assert(left.output.size == right.output.size) + val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) } + Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And))) + } +} + /** * Removes literals from group expressions in [[Aggregate]], as they have no effect to the result * but only makes the grouping key bigger. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index e9c970cd0808..16f4b355b1b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ @@ -90,12 +91,7 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } -abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - final override lazy val resolved: Boolean = - childrenResolved && - left.output.length == right.output.length && - left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } -} +abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode private[sql] object SetOperation { def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right)) @@ -103,15 +99,30 @@ private[sql] object SetOperation { case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { + def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty + override def output: Seq[Attribute] = left.output.zip(right.output).map { case (leftAttr, rightAttr) => leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable) } + + // Intersect are only resolved if they don't introduce ambiguous expression ids, + // since the Optimizer will convert Intersect to Join. + override lazy val resolved: Boolean = + childrenResolved && + left.output.length == right.output.length && + left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } && + duplicateResolved } case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { /** We don't use right.output because those rows get excluded from the set. */ override def output: Seq[Attribute] = left.output + + override lazy val resolved: Boolean = + childrenResolved && + left.output.length == right.output.length && + left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } } /** Factory for constructing new `Union` nodes. */ @@ -169,13 +180,13 @@ case class Join( } } - def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty + def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty // Joins are only resolved if they don't introduce ambiguous expression ids. 
override lazy val resolved: Boolean = { childrenResolved && expressions.forall(_.resolved) && - selfJoinResolved && + duplicateResolved && condition.forall(_.dataType == BooleanType) } } @@ -249,7 +260,7 @@ case class Range( end: Long, step: Long, numSlices: Int, - output: Seq[Attribute]) extends LeafNode { + output: Seq[Attribute]) extends LeafNode with MultiInstanceRelation { require(step != 0, "step cannot be 0") val numElements: BigInt = { val safeStart = BigInt(start) @@ -262,6 +273,9 @@ case class Range( } } + override def newInstance(): Range = + Range(start, end, step, numSlices, output.map(_.newInstance())) + override def statistics: Statistics = { val sizeInBytes = LongType.defaultSize * numElements Statistics( sizeInBytes = sizeInBytes ) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index ab680282208c..1938bce02a17 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -154,6 +154,11 @@ class AnalysisSuite extends AnalysisTest { checkAnalysis(plan, expected) } + test("self intersect should resolve duplicate expression IDs") { + val plan = testRelation.intersect(testRelation) + assertAnalysisSuccess(plan) + } + test("SPARK-8654: invalid CAST in NULL IN(...) expression") { val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(2))), "a")() :: Nil, LocalRelation() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index 37148a226f29..a4a12c0d62e9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -28,21 +28,9 @@ class AggregateOptimizeSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Aggregate", FixedPoint(100), - ReplaceDistinctWithAggregate, RemoveLiteralFromGroupExpressions) :: Nil } - test("replace distinct with aggregate") { - val input = LocalRelation('a.int, 'b.int) - - val query = Distinct(input) - val optimized = Optimize.execute(query.analyze) - - val correctAnswer = Aggregate(input.output, input.output, input) - - comparePlans(optimized, correctAnswer) - } - test("remove literals in grouping expression") { val input = LocalRelation('a.int, 'b.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala new file mode 100644 index 000000000000..f8ae5d9be208 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class ReplaceOperatorSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Replace Operators", FixedPoint(100), + ReplaceDistinctWithAggregate, + ReplaceIntersectWithSemiJoin) :: Nil + } + + test("replace Intersect with Left-semi Join") { + val table1 = LocalRelation('a.int, 'b.int) + val table2 = LocalRelation('c.int, 'd.int) + + val query = Intersect(table1, table2) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + Aggregate(table1.output, table1.output, + Join(table1, table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd))).analyze + + comparePlans(optimized, correctAnswer) + } + + test("replace Distinct with Aggregate") { + val input = LocalRelation('a.int, 'b.int) + + val query = Distinct(input) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = Aggregate(input.output, input.output, input) + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index 2283f7c008ba..b8ea32b4dfe0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -39,7 +39,6 @@ class SetOperationSuite extends PlanTest { val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) val testRelation3 = LocalRelation('g.int, 'h.int, 'i.int) val testUnion = Union(testRelation :: testRelation2 :: testRelation3 :: Nil) - val testIntersect = Intersect(testRelation, testRelation2) val testExcept = Except(testRelation, testRelation2) test("union: combine unions into one unions") { @@ -57,19 +56,12 @@ class SetOperationSuite extends PlanTest { comparePlans(combinedUnionsOptimized, unionOptimized3) } - test("intersect/except: filter to each side") { - val intersectQuery = testIntersect.where('b < 10) + test("except: filter to each side") { val exceptQuery = testExcept.where('c >= 5) - - val intersectOptimized = Optimize.execute(intersectQuery.analyze) val exceptOptimized = Optimize.execute(exceptQuery.analyze) - - val intersectCorrectAnswer = - Intersect(testRelation.where('b < 10), testRelation2.where('e < 10)).analyze val exceptCorrectAnswer = Except(testRelation.where('c >= 5), testRelation2.where('f >= 5)).analyze - comparePlans(intersectOptimized, intersectCorrectAnswer) comparePlans(exceptOptimized, exceptCorrectAnswer) } @@ -95,13 +87,8 @@ class SetOperationSuite extends PlanTest { } test("SPARK-10539: Project should not be pushed down through Intersect or Except") { - val intersectQuery = testIntersect.select('b, 'c) val exceptQuery = 
testExcept.select('a, 'b, 'c) - - val intersectOptimized = Optimize.execute(intersectQuery.analyze) val exceptOptimized = Optimize.execute(exceptQuery.analyze) - - comparePlans(intersectOptimized, intersectQuery.analyze) comparePlans(exceptOptimized, exceptQuery.analyze) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 60fbb595e575..9293e5514175 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -298,6 +298,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Distinct(child) => throw new IllegalStateException( "logical distinct operator should have been replaced by aggregate in the optimizer") + case logical.Intersect(left, right) => + throw new IllegalStateException( + "logical intersect operator should have been replaced by semi-join in the optimizer") case logical.MapPartitions(f, in, out, child) => execution.MapPartitions(f, in, out, planLater(child)) :: Nil @@ -340,8 +343,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Union(unionChildren.map(planLater)) :: Nil case logical.Except(left, right) => execution.Except(planLater(left), planLater(right)) :: Nil - case logical.Intersect(left, right) => - execution.Intersect(planLater(left), planLater(right)) :: Nil case g @ logical.Generate(generator, join, outer, _, _, child) => execution.Generate( generator, join = join, outer = outer, g.output, planLater(child)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index e7a73d5fbb4b..fd81531c9316 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -420,18 +420,6 @@ case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode { } } -/** - * Returns the rows in left that also appear in right using the built in spark - * intersection function. - */ -case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { - override def output: Seq[Attribute] = children.head.output - - protected override def doExecute(): RDD[InternalRow] = { - left.execute().map(_.copy()).intersection(right.execute().map(_.copy())) - } -} - /** * A plan node that does nothing but lie about the output of its child. 
Used to spice a * (hopefully structurally equivalent) tree from a different optimization sequence into an already diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 09bbe57a43ce..4ff99bdf2937 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -349,6 +349,27 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row(3, "c") :: Row(4, "d") :: Nil) checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.intersect(nullInts), + Row(1) :: + Row(2) :: + Row(3) :: + Row(null) :: Nil) + + // check if values are de-duplicated + checkAnswer( + allNulls.intersect(allNulls), + Row(null) :: Nil) + + // check if values are de-duplicated + val df = Seq(("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)).toDF("id", "value") + checkAnswer( + df.intersect(df), + Row("id1", 1) :: + Row("id", 1) :: + Row("id1", 2) :: Nil) } test("intersect - nullability") { From 2b027e9a386fe4009f61ad03b169335af5a9a5c6 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 29 Jan 2016 12:01:13 -0800 Subject: [PATCH 064/131] [SPARK-12818] Polishes spark-sketch module Fixes various minor code and Javadoc styling issues. Author: Cheng Lian Closes #10985 from liancheng/sketch-polishing. --- .../apache/spark/util/sketch/BitArray.java | 2 +- .../apache/spark/util/sketch/BloomFilter.java | 111 ++++++++++-------- .../spark/util/sketch/BloomFilterImpl.java | 40 ++++--- .../spark/util/sketch/CountMinSketch.java | 26 ++-- .../spark/util/sketch/CountMinSketchImpl.java | 12 ++ .../org/apache/spark/util/sketch/Utils.java | 2 +- 6 files changed, 110 insertions(+), 83 deletions(-) diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java index 2a0484e324b1..480a0a79db32 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java @@ -22,7 +22,7 @@ import java.io.IOException; import java.util.Arrays; -public final class BitArray { +final class BitArray { private final long[] data; private long bitCount; diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index 81772fcea0ec..c0b425e72959 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -22,16 +22,10 @@ import java.io.OutputStream; /** - * A Bloom filter is a space-efficient probabilistic data structure, that is used to test whether - * an element is a member of a set. It returns false when the element is definitely not in the - * set, returns true when the element is probably in the set. - * - * Internally a Bloom filter is initialized with 2 information: how many space to use(number of - * bits) and how many hash values to calculate for each record. To get as lower false positive - * probability as possible, user should call {@link BloomFilter#create} to automatically pick a - * best combination of these 2 parameters. 
- * - * Currently the following data types are supported: + * A Bloom filter is a space-efficient probabilistic data structure that offers an approximate + * containment test with one-sided error: if it claims that an item is contained in it, this + * might be in error, but if it claims that an item is not contained in it, then this is + * definitely true. Currently supported data types include: *
  * <ul>
  *   <li>{@link Byte}</li>
  *   <li>{@link Short}</li>
@@ -39,14 +33,17 @@
  *   <li>{@link Long}</li>
  *   <li>{@link String}</li>
  * </ul>
+ * The false positive probability ({@code FPP}) of a Bloom filter is defined as the probability that
+ * {@linkplain #mightContain(Object)} will erroneously return {@code true} for an object that has
+ * not actually been put in the {@code BloomFilter}.
  *
- * The implementation is largely based on the {@code BloomFilter} class from guava.
+ * The implementation is largely based on the {@code BloomFilter} class from Guava.
  */
 public abstract class BloomFilter {

   public enum Version {
     /**
-     * {@code BloomFilter} binary format version 1 (all values written in big-endian order):
+     * {@code BloomFilter} binary format version 1. All values written in big-endian order:
      * <ul>
      *   <li>Version number, always 1 (32 bit)</li>
      *   <li>Number of hash functions (32 bit)</li>
      • @@ -68,14 +65,13 @@ int getVersionNumber() { } /** - * Returns the false positive probability, i.e. the probability that - * {@linkplain #mightContain(Object)} will erroneously return {@code true} for an object that - * has not actually been put in the {@code BloomFilter}. + * Returns the probability that {@linkplain #mightContain(Object)} erroneously return {@code true} + * for an object that has not actually been put in the {@code BloomFilter}. * - *

        Ideally, this number should be close to the {@code fpp} parameter - * passed in to create this bloom filter, or smaller. If it is - * significantly higher, it is usually the case that too many elements (more than - * expected) have been put in the {@code BloomFilter}, degenerating it. + * Ideally, this number should be close to the {@code fpp} parameter passed in + * {@linkplain #create(long, double)}, or smaller. If it is significantly higher, it is usually + * the case that too many items (more than expected) have been put in the {@code BloomFilter}, + * degenerating it. */ public abstract double expectedFpp(); @@ -85,8 +81,8 @@ int getVersionNumber() { public abstract long bitSize(); /** - * Puts an element into this {@code BloomFilter}. Ensures that subsequent invocations of - * {@link #mightContain(Object)} with the same element will always return {@code true}. + * Puts an item into this {@code BloomFilter}. Ensures that subsequent invocations of + * {@linkplain #mightContain(Object)} with the same item will always return {@code true}. * * @return true if the bloom filter's bits changed as a result of this operation. If the bits * changed, this is definitely the first time {@code object} has been added to the @@ -98,19 +94,19 @@ int getVersionNumber() { public abstract boolean put(Object item); /** - * A specialized variant of {@link #put(Object)}, that can only be used to put utf-8 string. + * A specialized variant of {@link #put(Object)} that only supports {@code String} items. */ - public abstract boolean putString(String str); + public abstract boolean putString(String item); /** - * A specialized variant of {@link #put(Object)}, that can only be used to put long. + * A specialized variant of {@link #put(Object)} that only supports {@code long} items. */ - public abstract boolean putLong(long l); + public abstract boolean putLong(long item); /** - * A specialized variant of {@link #put(Object)}, that can only be used to put byte array. + * A specialized variant of {@link #put(Object)} that only supports byte array items. */ - public abstract boolean putBinary(byte[] bytes); + public abstract boolean putBinary(byte[] item); /** * Determines whether a given bloom filter is compatible with this bloom filter. For two @@ -137,38 +133,36 @@ int getVersionNumber() { public abstract boolean mightContain(Object item); /** - * A specialized variant of {@link #mightContain(Object)}, that can only be used to test utf-8 - * string. + * A specialized variant of {@link #mightContain(Object)} that only tests {@code String} items. */ - public abstract boolean mightContainString(String str); + public abstract boolean mightContainString(String item); /** - * A specialized variant of {@link #mightContain(Object)}, that can only be used to test long. + * A specialized variant of {@link #mightContain(Object)} that only tests {@code long} items. */ - public abstract boolean mightContainLong(long l); + public abstract boolean mightContainLong(long item); /** - * A specialized variant of {@link #mightContain(Object)}, that can only be used to test byte - * array. + * A specialized variant of {@link #mightContain(Object)} that only tests byte array items. */ - public abstract boolean mightContainBinary(byte[] bytes); + public abstract boolean mightContainBinary(byte[] item); /** - * Writes out this {@link BloomFilter} to an output stream in binary format. - * It is the caller's responsibility to close the stream. + * Writes out this {@link BloomFilter} to an output stream in binary format. 
It is the caller's + * responsibility to close the stream. */ public abstract void writeTo(OutputStream out) throws IOException; /** - * Reads in a {@link BloomFilter} from an input stream. - * It is the caller's responsibility to close the stream. + * Reads in a {@link BloomFilter} from an input stream. It is the caller's responsibility to close + * the stream. */ public static BloomFilter readFrom(InputStream in) throws IOException { return BloomFilterImpl.readFrom(in); } /** - * Computes the optimal k (number of hashes per element inserted in Bloom filter), given the + * Computes the optimal k (number of hashes per item inserted in Bloom filter), given the * expected insertions and total number of bits in the Bloom filter. * * See http://en.wikipedia.org/wiki/File:Bloom_filter_fp_probability.svg for the formula. @@ -197,21 +191,31 @@ private static long optimalNumOfBits(long n, double p) { static final double DEFAULT_FPP = 0.03; /** - * Creates a {@link BloomFilter} with given {@code expectedNumItems} and the default {@code fpp}. + * Creates a {@link BloomFilter} with the expected number of insertions and a default expected + * false positive probability of 3%. + * + * Note that overflowing a {@code BloomFilter} with significantly more elements than specified, + * will result in its saturation, and a sharp deterioration of its false positive probability. */ public static BloomFilter create(long expectedNumItems) { return create(expectedNumItems, DEFAULT_FPP); } /** - * Creates a {@link BloomFilter} with given {@code expectedNumItems} and {@code fpp}, it will pick - * an optimal {@code numBits} and {@code numHashFunctions} for the bloom filter. + * Creates a {@link BloomFilter} with the expected number of insertions and expected false + * positive probability. + * + * Note that overflowing a {@code BloomFilter} with significantly more elements than specified, + * will result in its saturation, and a sharp deterioration of its false positive probability. */ public static BloomFilter create(long expectedNumItems, double fpp) { - assert fpp > 0.0 : "False positive probability must be > 0.0"; - assert fpp < 1.0 : "False positive probability must be < 1.0"; - long numBits = optimalNumOfBits(expectedNumItems, fpp); - return create(expectedNumItems, numBits); + if (fpp <= 0D || fpp >= 1D) { + throw new IllegalArgumentException( + "False positive probability must be within range (0.0, 1.0)" + ); + } + + return create(expectedNumItems, optimalNumOfBits(expectedNumItems, fpp)); } /** @@ -219,9 +223,14 @@ public static BloomFilter create(long expectedNumItems, double fpp) { * pick an optimal {@code numHashFunctions} which can minimize {@code fpp} for the bloom filter. 
*/ public static BloomFilter create(long expectedNumItems, long numBits) { - assert expectedNumItems > 0 : "Expected insertions must be > 0"; - assert numBits > 0 : "number of bits must be > 0"; - int numHashFunctions = optimalNumOfHashFunctions(expectedNumItems, numBits); - return new BloomFilterImpl(numHashFunctions, numBits); + if (expectedNumItems <= 0) { + throw new IllegalArgumentException("Expected insertions must be positive"); + } + + if (numBits <= 0) { + throw new IllegalArgumentException("Number of bits must be positive"); + } + + return new BloomFilterImpl(optimalNumOfHashFunctions(expectedNumItems, numBits), numBits); } } diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java index 35107e0b389d..92c28bcb56a5 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java @@ -19,9 +19,10 @@ import java.io.*; -public class BloomFilterImpl extends BloomFilter implements Serializable { +class BloomFilterImpl extends BloomFilter implements Serializable { private int numHashFunctions; + private BitArray bits; BloomFilterImpl(int numHashFunctions, long numBits) { @@ -77,14 +78,14 @@ public boolean put(Object item) { } @Override - public boolean putString(String str) { - return putBinary(Utils.getBytesFromUTF8String(str)); + public boolean putString(String item) { + return putBinary(Utils.getBytesFromUTF8String(item)); } @Override - public boolean putBinary(byte[] bytes) { - int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0); - int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1); + public boolean putBinary(byte[] item) { + int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, 0); + int h2 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, h1); long bitSize = bits.bitSize(); boolean bitsChanged = false; @@ -100,14 +101,14 @@ public boolean putBinary(byte[] bytes) { } @Override - public boolean mightContainString(String str) { - return mightContainBinary(Utils.getBytesFromUTF8String(str)); + public boolean mightContainString(String item) { + return mightContainBinary(Utils.getBytesFromUTF8String(item)); } @Override - public boolean mightContainBinary(byte[] bytes) { - int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0); - int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1); + public boolean mightContainBinary(byte[] item) { + int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, 0); + int h2 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, h1); long bitSize = bits.bitSize(); for (int i = 1; i <= numHashFunctions; i++) { @@ -124,14 +125,14 @@ public boolean mightContainBinary(byte[] bytes) { } @Override - public boolean putLong(long l) { + public boolean putLong(long item) { // Here we first hash the input long element into 2 int hash values, h1 and h2, then produce n // hash values by `h1 + i * h2` with 1 <= i <= numHashFunctions. // Note that `CountMinSketch` use a different strategy, it hash the input long element with // every i to produce n hash values. // TODO: the strategy of `CountMinSketch` looks more advanced, should we follow it here? 
- int h1 = Murmur3_x86_32.hashLong(l, 0); - int h2 = Murmur3_x86_32.hashLong(l, h1); + int h1 = Murmur3_x86_32.hashLong(item, 0); + int h2 = Murmur3_x86_32.hashLong(item, h1); long bitSize = bits.bitSize(); boolean bitsChanged = false; @@ -147,9 +148,9 @@ public boolean putLong(long l) { } @Override - public boolean mightContainLong(long l) { - int h1 = Murmur3_x86_32.hashLong(l, 0); - int h2 = Murmur3_x86_32.hashLong(l, h1); + public boolean mightContainLong(long item) { + int h1 = Murmur3_x86_32.hashLong(item, 0); + int h2 = Murmur3_x86_32.hashLong(item, h1); long bitSize = bits.bitSize(); for (int i = 1; i <= numHashFunctions; i++) { @@ -197,7 +198,7 @@ public BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeExcep throw new IncompatibleMergeException("Cannot merge null bloom filter"); } - if (!(other instanceof BloomFilter)) { + if (!(other instanceof BloomFilterImpl)) { throw new IncompatibleMergeException( "Cannot merge bloom filter of class " + other.getClass().getName() ); @@ -211,7 +212,8 @@ public BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeExcep if (this.numHashFunctions != that.numHashFunctions) { throw new IncompatibleMergeException( - "Cannot merge bloom filters with different number of hash functions"); + "Cannot merge bloom filters with different number of hash functions" + ); } this.bits.putAll(that.bits); diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index f0aac5bb00df..48f98680f48c 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -22,7 +22,7 @@ import java.io.OutputStream; /** - * A Count-Min sketch is a probabilistic data structure used for summarizing streams of data in + * A Count-min sketch is a probabilistic data structure used for summarizing streams of data in * sub-linear space. Currently, supported data types include: *

          *
        • {@link Byte}
        • @@ -31,8 +31,7 @@ *
        • {@link Long}
        • *
        • {@link String}
        • *
        - * Each {@link CountMinSketch} is initialized with a random seed, and a pair - * of parameters: + * A {@link CountMinSketch} is initialized with a random seed, and a pair of parameters: *
          *
        1. relative error (or {@code eps}), and *
        2. confidence (or {@code delta}) @@ -49,16 +48,13 @@ *
        3. {@code w = ceil(-log(1 - confidence) / log(2))}
        4. *
      * - * See http://www.eecs.harvard.edu/~michaelm/CS222/countmin.pdf for technical details, - * including proofs of the estimates and error bounds used in this implementation. - * * This implementation is largely based on the {@code CountMinSketch} class from stream-lib. */ abstract public class CountMinSketch { public enum Version { /** - * {@code CountMinSketch} binary format version 1 (all values written in big-endian order): + * {@code CountMinSketch} binary format version 1. All values written in big-endian order: *
        *
      • Version number, always 1 (32 bit)
      • *
      • Total count of added items (64 bit)
      • @@ -172,14 +168,14 @@ public abstract CountMinSketch mergeInPlace(CountMinSketch other) throws IncompatibleMergeException; /** - * Writes out this {@link CountMinSketch} to an output stream in binary format. - * It is the caller's responsibility to close the stream. + * Writes out this {@link CountMinSketch} to an output stream in binary format. It is the caller's + * responsibility to close the stream. */ public abstract void writeTo(OutputStream out) throws IOException; /** - * Reads in a {@link CountMinSketch} from an input stream. - * It is the caller's responsibility to close the stream. + * Reads in a {@link CountMinSketch} from an input stream. It is the caller's responsibility to + * close the stream. */ public static CountMinSketch readFrom(InputStream in) throws IOException { return CountMinSketchImpl.readFrom(in); @@ -188,6 +184,10 @@ public static CountMinSketch readFrom(InputStream in) throws IOException { /** * Creates a {@link CountMinSketch} with given {@code depth}, {@code width}, and random * {@code seed}. + * + * @param depth depth of the Count-min Sketch, must be positive + * @param width width of the Count-min Sketch, must be positive + * @param seed random seed */ public static CountMinSketch create(int depth, int width, int seed) { return new CountMinSketchImpl(depth, width, seed); @@ -196,6 +196,10 @@ public static CountMinSketch create(int depth, int width, int seed) { /** * Creates a {@link CountMinSketch} with given relative error ({@code eps}), {@code confidence}, * and random {@code seed}. + * + * @param eps relative error, must be positive + * @param confidence confidence, must be positive and less than 1.0 + * @param seed random seed */ public static CountMinSketch create(double eps, double confidence, int seed) { return new CountMinSketchImpl(eps, confidence, seed); diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index c0631c6778df..2acbb247b13c 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -42,6 +42,10 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable { private CountMinSketchImpl() {} CountMinSketchImpl(int depth, int width, int seed) { + if (depth <= 0 || width <= 0) { + throw new IllegalArgumentException("Depth and width must be both positive"); + } + this.depth = depth; this.width = width; this.eps = 2.0 / width; @@ -50,6 +54,14 @@ private CountMinSketchImpl() {} } CountMinSketchImpl(double eps, double confidence, int seed) { + if (eps <= 0D) { + throw new IllegalArgumentException("Relative error must be positive"); + } + + if (confidence <= 0D || confidence >= 1D) { + throw new IllegalArgumentException("Confidence must be within range (0.0, 1.0)"); + } + // 2/w = eps ; w = 2/eps // 1/2^depth <= 1-confidence ; depth >= -log2 (1-confidence) this.eps = eps; diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java index a6b33313035b..feb601d44f39 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java @@ -19,7 +19,7 @@ import java.io.UnsupportedEncodingException; -public class Utils { +class Utils { public static byte[] getBytesFromUTF8String(String str) { try { return 
str.getBytes("utf-8"); From e38b0baa38c6894335f187eaa4c8ea5c02d4563b Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 29 Jan 2016 13:45:03 -0800 Subject: [PATCH 065/131] [SPARK-13055] SQLHistoryListener throws ClassCastException This is an existing issue uncovered recently by #10835. The reason for the exception was because the `SQLHistoryListener` gets all sorts of accumulators, not just the ones that represent SQL metrics. For example, the listener gets the `internal.metrics.shuffleRead.remoteBlocksFetched`, which is an Int, then it proceeds to cast the Int to a Long, which fails. The fix is to mark accumulators representing SQL metrics using some internal metadata. Then we can identify which ones are SQL metrics and only process those in the `SQLHistoryListener`. Author: Andrew Or Closes #10971 from andrewor14/fix-sql-history. --- .../scala/org/apache/spark/Accumulable.scala | 8 ++++ .../apache/spark/executor/TaskMetrics.scala | 4 +- .../spark/scheduler/AccumulableInfo.scala | 5 ++- .../apache/spark/scheduler/DAGScheduler.scala | 7 +--- .../org/apache/spark/util/JsonProtocol.scala | 6 ++- .../spark/executor/TaskMetricsSuite.scala | 4 +- .../spark/scheduler/DAGSchedulerSuite.scala | 14 ++----- .../spark/scheduler/TaskSetManagerSuite.scala | 8 +--- .../apache/spark/util/JsonProtocolSuite.scala | 16 +++++--- .../sql/execution/metric/SQLMetrics.scala | 21 +++++++++- .../spark/sql/execution/ui/SQLListener.scala | 23 +++++++---- .../execution/metric/SQLMetricsSuite.scala | 24 +++++++++++- .../sql/execution/ui/SQLListenerSuite.scala | 38 ++++++++++++++++++- 13 files changed, 133 insertions(+), 45 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Accumulable.scala b/core/src/main/scala/org/apache/spark/Accumulable.scala index 52f572b63fa9..601b503d12c7 100644 --- a/core/src/main/scala/org/apache/spark/Accumulable.scala +++ b/core/src/main/scala/org/apache/spark/Accumulable.scala @@ -22,6 +22,7 @@ import java.io.{ObjectInputStream, Serializable} import scala.collection.generic.Growable import scala.reflect.ClassTag +import org.apache.spark.scheduler.AccumulableInfo import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util.Utils @@ -187,6 +188,13 @@ class Accumulable[R, T] private ( */ private[spark] def setValueAny(newValue: Any): Unit = { setValue(newValue.asInstanceOf[R]) } + /** + * Create an [[AccumulableInfo]] representation of this [[Accumulable]] with the provided values. + */ + private[spark] def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = { + new AccumulableInfo(id, name, update, value, internal, countFailedValues) + } + // Called by Java when deserializing an object private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { in.defaultReadObject() diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 8d10bf588ef1..0a6ebcb3e029 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -323,8 +323,8 @@ class TaskMetrics(initialAccums: Seq[Accumulator[_]]) extends Serializable { * field is always empty, since this represents the partial updates recorded in this task, * not the aggregated value across multiple tasks. 
*/ - def accumulatorUpdates(): Seq[AccumulableInfo] = accums.map { a => - new AccumulableInfo(a.id, a.name, Some(a.localValue), None, a.isInternal, a.countFailedValues) + def accumulatorUpdates(): Seq[AccumulableInfo] = { + accums.map { a => a.toInfo(Some(a.localValue), None) } } // If we are reconstructing this TaskMetrics on the driver, some metrics may already be set. diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala index 9d45fff9213c..cedacad44afe 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -35,6 +35,7 @@ import org.apache.spark.annotation.DeveloperApi * @param value total accumulated value so far, maybe None if used on executors to describe a task * @param internal whether this accumulator was internal * @param countFailedValues whether to count this accumulator's partial value if the task failed + * @param metadata internal metadata associated with this accumulator, if any */ @DeveloperApi case class AccumulableInfo private[spark] ( @@ -43,7 +44,9 @@ case class AccumulableInfo private[spark] ( update: Option[Any], // represents a partial update within a task value: Option[Any], private[spark] val internal: Boolean, - private[spark] val countFailedValues: Boolean) + private[spark] val countFailedValues: Boolean, + // TODO: use this to identify internal task metrics instead of encoding it in the name + private[spark] val metadata: Option[String] = None) /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 897479b50010..ee0b8a1c95fd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1101,11 +1101,8 @@ class DAGScheduler( acc ++= partialValue // To avoid UI cruft, ignore cases where value wasn't updated if (acc.name.isDefined && partialValue != acc.zero) { - val name = acc.name - stage.latestInfo.accumulables(id) = new AccumulableInfo( - id, name, None, Some(acc.value), acc.isInternal, acc.countFailedValues) - event.taskInfo.accumulables += new AccumulableInfo( - id, name, Some(partialValue), Some(acc.value), acc.isInternal, acc.countFailedValues) + stage.latestInfo.accumulables(id) = acc.toInfo(None, Some(acc.value)) + event.taskInfo.accumulables += acc.toInfo(Some(partialValue), Some(acc.value)) } } } catch { diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index dc8070cf8aad..a2487eeb0483 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -290,7 +290,8 @@ private[spark] object JsonProtocol { ("Update" -> accumulableInfo.update.map { v => accumValueToJson(name, v) }) ~ ("Value" -> accumulableInfo.value.map { v => accumValueToJson(name, v) }) ~ ("Internal" -> accumulableInfo.internal) ~ - ("Count Failed Values" -> accumulableInfo.countFailedValues) + ("Count Failed Values" -> accumulableInfo.countFailedValues) ~ + ("Metadata" -> accumulableInfo.metadata) } /** @@ -728,7 +729,8 @@ private[spark] object JsonProtocol { val value = Utils.jsonOption(json \ "Value").map { v => accumValueFromJson(name, v) } val internal = (json \ "Internal").extractOpt[Boolean].getOrElse(false) val 
countFailedValues = (json \ "Count Failed Values").extractOpt[Boolean].getOrElse(false) - new AccumulableInfo(id, name, update, value, internal, countFailedValues) + val metadata = (json \ "Metadata").extractOpt[String] + new AccumulableInfo(id, name, update, value, internal, countFailedValues, metadata) } /** diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala index 15be0b194ed8..67c4595ed192 100644 --- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala @@ -551,8 +551,6 @@ private[spark] object TaskMetricsSuite extends Assertions { * Make an [[AccumulableInfo]] out of an [[Accumulable]] with the intent to use the * info as an accumulator update. */ - def makeInfo(a: Accumulable[_, _]): AccumulableInfo = { - new AccumulableInfo(a.id, a.name, Some(a.value), None, a.isInternal, a.countFailedValues) - } + def makeInfo(a: Accumulable[_, _]): AccumulableInfo = a.toInfo(Some(a.value), None) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index d9c71ec2eae7..62972a073821 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1581,12 +1581,9 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou assert(Accumulators.get(acc1.id).isDefined) assert(Accumulators.get(acc2.id).isDefined) assert(Accumulators.get(acc3.id).isDefined) - val accInfo1 = new AccumulableInfo( - acc1.id, acc1.name, Some(15L), None, internal = false, countFailedValues = false) - val accInfo2 = new AccumulableInfo( - acc2.id, acc2.name, Some(13L), None, internal = false, countFailedValues = false) - val accInfo3 = new AccumulableInfo( - acc3.id, acc3.name, Some(18L), None, internal = false, countFailedValues = false) + val accInfo1 = acc1.toInfo(Some(15L), None) + val accInfo2 = acc2.toInfo(Some(13L), None) + val accInfo3 = acc3.toInfo(Some(18L), None) val accumUpdates = Seq(accInfo1, accInfo2, accInfo3) val exceptionFailure = new ExceptionFailure(new SparkException("fondue?"), accumUpdates) submit(new MyRDD(sc, 1, Nil), Array(0)) @@ -1954,10 +1951,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou extraAccumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo], taskInfo: TaskInfo = createFakeTaskInfo()): CompletionEvent = { val accumUpdates = reason match { - case Success => - task.initialAccumulators.map { a => - new AccumulableInfo(a.id, a.name, Some(a.zero), None, a.isInternal, a.countFailedValues) - } + case Success => task.initialAccumulators.map { a => a.toInfo(Some(a.zero), None) } case ef: ExceptionFailure => ef.accumUpdates case _ => Seq.empty[AccumulableInfo] } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index a2e74365641a..2c99dd5afb32 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -165,9 +165,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val taskSet = FakeTask.createTaskSet(1) val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, 
MAX_TASK_FAILURES, clock) - val accumUpdates = taskSet.tasks.head.initialAccumulators.map { a => - new AccumulableInfo(a.id, a.name, Some(0L), None, a.isInternal, a.countFailedValues) - } + val accumUpdates = taskSet.tasks.head.initialAccumulators.map { a => a.toInfo(Some(0L), None) } // Offer a host with NO_PREF as the constraint, // we should get a nopref task immediately since that's what we only have @@ -186,9 +184,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val taskSet = FakeTask.createTaskSet(3) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) val accumUpdatesByTask: Array[Seq[AccumulableInfo]] = taskSet.tasks.map { task => - task.initialAccumulators.map { a => - new AccumulableInfo(a.id, a.name, Some(0L), None, a.isInternal, a.countFailedValues) - } + task.initialAccumulators.map { a => a.toInfo(Some(0L), None) } } // First three offers should all find tasks diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 57021d1d3d52..48951c316803 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -374,15 +374,18 @@ class JsonProtocolSuite extends SparkFunSuite { test("AccumulableInfo backward compatibility") { // "Internal" property of AccumulableInfo was added in 1.5.1 val accumulableInfo = makeAccumulableInfo(1, internal = true, countFailedValues = true) - val oldJson = JsonProtocol.accumulableInfoToJson(accumulableInfo) - .removeField({ _._1 == "Internal" }) + val accumulableInfoJson = JsonProtocol.accumulableInfoToJson(accumulableInfo) + val oldJson = accumulableInfoJson.removeField({ _._1 == "Internal" }) val oldInfo = JsonProtocol.accumulableInfoFromJson(oldJson) assert(!oldInfo.internal) // "Count Failed Values" property of AccumulableInfo was added in 2.0.0 - val oldJson2 = JsonProtocol.accumulableInfoToJson(accumulableInfo) - .removeField({ _._1 == "Count Failed Values" }) + val oldJson2 = accumulableInfoJson.removeField({ _._1 == "Count Failed Values" }) val oldInfo2 = JsonProtocol.accumulableInfoFromJson(oldJson2) assert(!oldInfo2.countFailedValues) + // "Metadata" property of AccumulableInfo was added in 2.0.0 + val oldJson3 = accumulableInfoJson.removeField({ _._1 == "Metadata" }) + val oldInfo3 = JsonProtocol.accumulableInfoFromJson(oldJson3) + assert(oldInfo3.metadata.isEmpty) } test("ExceptionFailure backward compatibility: accumulator updates") { @@ -820,9 +823,10 @@ private[spark] object JsonProtocolSuite extends Assertions { private def makeAccumulableInfo( id: Int, internal: Boolean = false, - countFailedValues: Boolean = false): AccumulableInfo = + countFailedValues: Boolean = false, + metadata: Option[String] = None): AccumulableInfo = new AccumulableInfo(id, Some(s"Accumulable$id"), Some(s"delta$id"), Some(s"val$id"), - internal, countFailedValues) + internal, countFailedValues, metadata) /** * Creates a TaskMetrics object describing a task that read data from Hadoop (if hasHadoopInput is diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 950dc7816241..6b43d273fefd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -18,6 +18,7 @@ package 
org.apache.spark.sql.execution.metric import org.apache.spark.{Accumulable, AccumulableParam, Accumulators, SparkContext} +import org.apache.spark.scheduler.AccumulableInfo import org.apache.spark.util.Utils /** @@ -27,9 +28,16 @@ import org.apache.spark.util.Utils * An implementation of SQLMetric should override `+=` and `add` to avoid boxing. */ private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T]( - name: String, val param: SQLMetricParam[R, T]) + name: String, + val param: SQLMetricParam[R, T]) extends Accumulable[R, T](param.zero, param, Some(name), internal = true) { + // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later + override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = { + new AccumulableInfo(id, Some(name), update, value, isInternal, countFailedValues, + Some(SQLMetrics.ACCUM_IDENTIFIER)) + } + def reset(): Unit = { this.value = param.zero } @@ -73,6 +81,14 @@ private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetr // Although there is a boxing here, it's fine because it's only called in SQLListener override def value: Long = _value + + // Needed for SQLListenerSuite + override def equals(other: Any): Boolean = { + other match { + case o: LongSQLMetricValue => value == o.value + case _ => false + } + } } /** @@ -126,6 +142,9 @@ private object StaticsLongSQLMetricParam extends LongSQLMetricParam( private[sql] object SQLMetrics { + // Identifier for distinguishing SQL metrics from other accumulators + private[sql] val ACCUM_IDENTIFIER = "sql" + private def createLongMetric( sc: SparkContext, name: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 544606f1168b..835e7ba6c516 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -23,7 +23,7 @@ import org.apache.spark.{JobExecutionStatus, Logging, SparkConf} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution} -import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetricParam, SQLMetricValue} +import org.apache.spark.sql.execution.metric._ import org.apache.spark.ui.SparkUI @DeveloperApi @@ -314,14 +314,17 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi } + +/** + * A [[SQLListener]] for rendering the SQL UI in the history server. 
+ */ private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI) extends SQLListener(conf) { private var sqlTabAttached = false - override def onExecutorMetricsUpdate( - executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { - // Do nothing + override def onExecutorMetricsUpdate(u: SparkListenerExecutorMetricsUpdate): Unit = { + // Do nothing; these events are not logged } override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { @@ -329,9 +332,15 @@ private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI) taskEnd.taskInfo.taskId, taskEnd.stageId, taskEnd.stageAttemptId, - taskEnd.taskInfo.accumulables.map { a => - val newValue = new LongSQLMetricValue(a.update.map(_.asInstanceOf[Long]).getOrElse(0L)) - a.copy(update = Some(newValue)) + taskEnd.taskInfo.accumulables.flatMap { a => + // Filter out accumulators that are not SQL metrics + // For now we assume all SQL metrics are Long's that have been JSON serialized as String's + if (a.metadata.exists(_ == SQLMetrics.ACCUM_IDENTIFIER)) { + val newValue = new LongSQLMetricValue(a.update.map(_.toString.toLong).getOrElse(0L)) + Some(a.copy(update = Some(newValue))) + } else { + None + } }, finishTask = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 82f6811503c2..2260e4870299 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.util.Utils +import org.apache.spark.util.{JsonProtocol, Utils} class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { @@ -356,6 +356,28 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { } } + test("metrics can be loaded by history server") { + val metric = new LongSQLMetric("zanzibar", LongSQLMetricParam) + metric += 10L + val metricInfo = metric.toInfo(Some(metric.localValue), None) + metricInfo.update match { + case Some(v: LongSQLMetricValue) => assert(v.value === 10L) + case Some(v) => fail(s"metric value was not a LongSQLMetricValue: ${v.getClass.getName}") + case _ => fail("metric update is missing") + } + assert(metricInfo.metadata === Some(SQLMetrics.ACCUM_IDENTIFIER)) + // After serializing to JSON, the original value type is lost, but we can still + // identify that it's a SQL metric from the metadata + val metricInfoJson = JsonProtocol.accumulableInfoToJson(metricInfo) + val metricInfoDeser = JsonProtocol.accumulableInfoFromJson(metricInfoJson) + metricInfoDeser.update match { + case Some(v: String) => assert(v.toLong === 10L) + case Some(v) => fail(s"deserialized metric value was not a string: ${v.getClass.getName}") + case _ => fail("deserialized metric update is missing") + } + assert(metricInfoDeser.metadata === Some(SQLMetrics.ACCUM_IDENTIFIER)) + } + } private case class MethodIdentifier[T](cls: Class[T], name: String, desc: String) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 2c408c887847..085e4a49a57e 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -26,8 +26,9 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution} -import org.apache.spark.sql.execution.metric.LongSQLMetricValue +import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics} import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.ui.SparkUI class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { import testImplicits._ @@ -335,8 +336,43 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber + 1) } + test("SPARK-13055: history listener only tracks SQL metrics") { + val listener = new SQLHistoryListener(sparkContext.conf, mock(classOf[SparkUI])) + // We need to post other events for the listener to track our accumulators. + // These are largely just boilerplate unrelated to what we're trying to test. + val df = createTestDataFrame + val executionStart = SparkListenerSQLExecutionStart( + 0, "", "", "", SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), 0) + val stageInfo = createStageInfo(0, 0) + val jobStart = SparkListenerJobStart(0, 0, Seq(stageInfo), createProperties(0)) + val stageSubmitted = SparkListenerStageSubmitted(stageInfo) + // This task has both accumulators that are SQL metrics and accumulators that are not. + // The listener should only track the ones that are actually SQL metrics. + val sqlMetric = SQLMetrics.createLongMetric(sparkContext, "beach umbrella") + val nonSqlMetric = sparkContext.accumulator[Int](0, "baseball") + val sqlMetricInfo = sqlMetric.toInfo(Some(sqlMetric.localValue), None) + val nonSqlMetricInfo = nonSqlMetric.toInfo(Some(nonSqlMetric.localValue), None) + val taskInfo = createTaskInfo(0, 0) + taskInfo.accumulables ++= Seq(sqlMetricInfo, nonSqlMetricInfo) + val taskEnd = SparkListenerTaskEnd(0, 0, "just-a-task", null, taskInfo, null) + listener.onOtherEvent(executionStart) + listener.onJobStart(jobStart) + listener.onStageSubmitted(stageSubmitted) + // Before SPARK-13055, this throws ClassCastException because the history listener would + // assume that the accumulator value is of type Long, but this may not be true for + // accumulators that are not SQL metrics. + listener.onTaskEnd(taskEnd) + val trackedAccums = listener.stageIdToStageMetrics.values.flatMap { stageMetrics => + stageMetrics.taskIdToMetricUpdates.values.flatMap(_.accumulatorUpdates) + } + // Listener tracks only SQL metrics, not other accumulators + assert(trackedAccums.size === 1) + assert(trackedAccums.head === sqlMetricInfo) + } + } + class SQLListenerMemoryLeakSuite extends SparkFunSuite { test("no memory leak") { From 2cbc412821641cf9446c0621ffa1976bd7fc4fa1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 29 Jan 2016 16:57:34 -0800 Subject: [PATCH 066/131] [SPARK-13076][SQL] Rename ClientInterface -> HiveClient And ClientWrapper -> HiveClientImpl. I have some followup pull requests to introduce a new internal catalog, and I think this new naming reflects better the functionality of the two classes. Author: Reynold Xin Closes #10981 from rxin/SPARK-13076. 
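To see what the rename means at a call site, here is a hypothetical helper (names invented for illustration) written in terms of only the members visible in this patch, `newSession()` and `withHiveState`; it is a sketch, not part of the change itself.

import org.apache.spark.sql.hive.client.HiveClient

// After this patch callers program against the HiveClient trait (formerly
// ClientInterface); the instance behind it is a HiveClientImpl (formerly
// ClientWrapper), created through IsolatedClientLoader as before.
// (HiveClient is private[hive], so a real helper would live under
// org.apache.spark.sql.hive.)
def withFreshSession[A](client: HiveClient)(body: HiveClient => A): A = {
  val session = client.newSession() // shares the class loader and Hive client
  session.withHiveState { body(session) }
}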
--- .../apache/spark/sql/execution/commands.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 2 +- .../apache/spark/sql/hive/HiveContext.scala | 10 +++++----- .../spark/sql/hive/HiveMetastoreCatalog.scala | 2 +- ...ClientInterface.scala => HiveClient.scala} | 10 +++++----- ...ientWrapper.scala => HiveClientImpl.scala} | 20 +++++++++---------- .../spark/sql/hive/client/HiveShim.scala | 5 ++--- .../hive/client/IsolatedClientLoader.scala | 18 ++++++++--------- .../org/apache/spark/sql/hive/hiveUDFs.scala | 4 ++-- .../apache/spark/sql/hive/test/TestHive.scala | 4 ++-- .../spark/sql/hive/client/VersionsSuite.scala | 4 ++-- .../sql/hive/execution/SQLQuerySuite.scala | 2 +- 12 files changed, 41 insertions(+), 42 deletions(-) rename sql/hive/src/main/scala/org/apache/spark/sql/hive/client/{ClientInterface.scala => HiveClient.scala} (95%) rename sql/hive/src/main/scala/org/apache/spark/sql/hive/client/{ClientWrapper.scala => HiveClientImpl.scala} (97%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 703e4643cbd2..c6adb583f931 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -404,7 +404,7 @@ case class DescribeFunction( result } - case None => Seq(Row(s"Function: $functionName is not found.")) + case None => Seq(Row(s"Function: $functionName not found.")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 51a50c1fa30e..2b821c1056f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -84,7 +84,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { "Extended Usage") checkExistence(sql("describe functioN abcadf"), true, - "Function: abcadf is not found.") + "Function: abcadf not found.") } test("SPARK-6743: no columns from cache") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 1797ea54f250..05863ae18350 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -79,8 +79,8 @@ class HiveContext private[hive]( sc: SparkContext, cacheManager: CacheManager, listener: SQLListener, - @transient private val execHive: ClientWrapper, - @transient private val metaHive: ClientInterface, + @transient private val execHive: HiveClientImpl, + @transient private val metaHive: HiveClient, isRootContext: Boolean) extends SQLContext(sc, cacheManager, listener, isRootContext) with Logging { self => @@ -193,7 +193,7 @@ class HiveContext private[hive]( * for storing persistent metadata, and only point to a dummy metastore in a temporary directory. 
*/ @transient - protected[hive] lazy val executionHive: ClientWrapper = if (execHive != null) { + protected[hive] lazy val executionHive: HiveClientImpl = if (execHive != null) { execHive } else { logInfo(s"Initializing execution hive, version $hiveExecutionVersion") @@ -203,7 +203,7 @@ class HiveContext private[hive]( config = newTemporaryConfiguration(useInMemoryDerby = true), isolationOn = false, baseClassLoader = Utils.getContextOrSparkClassLoader) - loader.createClient().asInstanceOf[ClientWrapper] + loader.createClient().asInstanceOf[HiveClientImpl] } /** @@ -222,7 +222,7 @@ class HiveContext private[hive]( * in the hive-site.xml file. */ @transient - protected[hive] lazy val metadataHive: ClientInterface = if (metaHive != null) { + protected[hive] lazy val metadataHive: HiveClient = if (metaHive != null) { metaHive } else { val metaVersion = IsolatedClientLoader.hiveVersion(hiveMetastoreVersion) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 848aa4ec6fe5..61d0d6759ff7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -96,7 +96,7 @@ private[hive] object HiveSerDe { } } -private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: HiveContext) +private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveContext) extends Catalog with Logging { val conf = hive.conf diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala similarity index 95% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala index 4eec3fef7408..f681cc67041a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala @@ -60,9 +60,9 @@ private[hive] case class HiveTable( viewText: Option[String] = None) { @transient - private[client] var client: ClientInterface = _ + private[client] var client: HiveClient = _ - private[client] def withClient(ci: ClientInterface): this.type = { + private[client] def withClient(ci: HiveClient): this.type = { client = ci this } @@ -85,7 +85,7 @@ private[hive] case class HiveTable( * internal and external classloaders for a given version of Hive and thus must expose only * shared classes. */ -private[hive] trait ClientInterface { +private[hive] trait HiveClient { /** Returns the Hive Version of this client. 
*/ def version: HiveVersion @@ -184,8 +184,8 @@ private[hive] trait ClientInterface { /** Add a jar into class loader */ def addJar(path: String): Unit - /** Return a ClientInterface as new session, that will share the class loader and Hive client */ - def newSession(): ClientInterface + /** Return a [[HiveClient]] as new session, that will share the class loader and Hive client */ + def newSession(): HiveClient /** Run a function within Hive state (SessionState, HiveConf, Hive client and class loader) */ def withHiveState[A](f: => A): A diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala similarity index 97% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 5307e924e7e5..cf1ff55c96fc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -44,8 +44,8 @@ import org.apache.spark.util.{CircularBuffer, Utils} * A class that wraps the HiveClient and converts its responses to externally visible classes. * Note that this class is typically loaded with an internal classloader for each instantiation, * allowing it to interact directly with a specific isolated version of Hive. Loading this class - * with the isolated classloader however will result in it only being visible as a ClientInterface, - * not a ClientWrapper. + * with the isolated classloader however will result in it only being visible as a [[HiveClient]], + * not a [[HiveClientImpl]]. * * This class needs to interact with multiple versions of Hive, but will always be compiled with * the 'native', execution version of Hive. Therefore, any places where hive breaks compatibility @@ -55,14 +55,14 @@ import org.apache.spark.util.{CircularBuffer, Utils} * @param config a collection of configuration options that will be added to the hive conf before * opening the hive client. * @param initClassLoader the classloader used when creating the `state` field of - * this ClientWrapper. + * this [[HiveClientImpl]]. */ -private[hive] class ClientWrapper( +private[hive] class HiveClientImpl( override val version: HiveVersion, config: Map[String, String], initClassLoader: ClassLoader, val clientLoader: IsolatedClientLoader) - extends ClientInterface + extends HiveClient with Logging { // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. @@ -77,7 +77,7 @@ private[hive] class ClientWrapper( case hive.v1_2 => new Shim_v1_2() } - // Create an internal session state for this ClientWrapper. + // Create an internal session state for this HiveClientImpl. val state = { val original = Thread.currentThread().getContextClassLoader // Switch to the initClassLoader. 
@@ -160,7 +160,7 @@ private[hive] class ClientWrapper( case e: Exception if causedByThrift(e) => caughtException = e logWarning( - "HiveClientWrapper got thrift exception, destroying client and retrying " + + "HiveClient got thrift exception, destroying client and retrying " + s"(${retryLimit - numTries} tries remaining)", e) clientLoader.cachedHive = null Thread.sleep(retryDelayMillis) @@ -199,7 +199,7 @@ private[hive] class ClientWrapper( */ def withHiveState[A](f: => A): A = retryLocked { val original = Thread.currentThread().getContextClassLoader - // Set the thread local metastore client to the client associated with this ClientWrapper. + // Set the thread local metastore client to the client associated with this HiveClientImpl. Hive.set(client) // The classloader in clientLoader could be changed after addJar, always use the latest // classloader @@ -521,8 +521,8 @@ private[hive] class ClientWrapper( runSqlHive(s"ADD JAR $path") } - def newSession(): ClientWrapper = { - clientLoader.createClient().asInstanceOf[ClientWrapper] + def newSession(): HiveClientImpl = { + clientLoader.createClient().asInstanceOf[HiveClientImpl] } def reset(): Unit = withHiveState { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index ca636b0265d4..70c10be25be9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -38,8 +38,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{IntegralType, StringType} /** - * A shim that defines the interface between ClientWrapper and the underlying Hive library used to - * talk to the metastore. Each Hive version has its own implementation of this class, defining + * A shim that defines the interface between [[HiveClientImpl]] and the underlying Hive library used + * to talk to the metastore. Each Hive version has its own implementation of this class, defining * version-specific version of needed functions. * * The guideline for writing shims is: @@ -52,7 +52,6 @@ private[client] sealed abstract class Shim { /** * Set the current SessionState to the given SessionState. Also, set the context classloader of * the current thread to the one set in the HiveConf of this given `state`. - * @param state */ def setCurrentSessionState(state: SessionState): Unit diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 010051d255fd..dca7396ee1ab 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -124,15 +124,15 @@ private[hive] object IsolatedClientLoader extends Logging { } /** - * Creates a Hive `ClientInterface` using a classloader that works according to the following rules: + * Creates a [[HiveClient]] using a classloader that works according to the following rules: * - Shared classes: Java, Scala, logging, and Spark classes are delegated to `baseClassLoader` - * allowing the results of calls to the `ClientInterface` to be visible externally. + * allowing the results of calls to the [[HiveClient]] to be visible externally. * - Hive classes: new instances are loaded from `execJars`. These classes are not * accessible externally due to their custom loading. 
- * - ClientWrapper: a new copy is created for each instance of `IsolatedClassLoader`. + * - [[HiveClientImpl]]: a new copy is created for each instance of `IsolatedClassLoader`. * This new instance is able to see a specific version of hive without using reflection. Since * this is a unique instance, it is not visible externally other than as a generic - * `ClientInterface`, unless `isolationOn` is set to `false`. + * [[HiveClient]], unless `isolationOn` is set to `false`. * * @param version The version of hive on the classpath. used to pick specific function signatures * that are not compatible across versions. @@ -179,7 +179,7 @@ private[hive] class IsolatedClientLoader( /** True if `name` refers to a spark class that must see specific version of Hive. */ protected def isBarrierClass(name: String): Boolean = - name.startsWith(classOf[ClientWrapper].getName) || + name.startsWith(classOf[HiveClientImpl].getName) || name.startsWith(classOf[Shim].getName) || barrierPrefixes.exists(name.startsWith) @@ -233,9 +233,9 @@ private[hive] class IsolatedClientLoader( } /** The isolated client interface to Hive. */ - private[hive] def createClient(): ClientInterface = { + private[hive] def createClient(): HiveClient = { if (!isolationOn) { - return new ClientWrapper(version, config, baseClassLoader, this) + return new HiveClientImpl(version, config, baseClassLoader, this) } // Pre-reflective instantiation setup. logDebug("Initializing the logger to avoid disaster...") @@ -244,10 +244,10 @@ private[hive] class IsolatedClientLoader( try { classLoader - .loadClass(classOf[ClientWrapper].getName) + .loadClass(classOf[HiveClientImpl].getName) .getConstructors.head .newInstance(version, config, classLoader, this) - .asInstanceOf[ClientInterface] + .asInstanceOf[HiveClient] } catch { case e: InvocationTargetException => if (e.getCause().isInstanceOf[NoClassDefFoundError]) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 56cab1aee89d..d5ed838ca4b1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -38,13 +38,13 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.util.sequenceOption import org.apache.spark.sql.hive.HiveShim._ -import org.apache.spark.sql.hive.client.ClientWrapper +import org.apache.spark.sql.hive.client.HiveClientImpl import org.apache.spark.sql.types._ private[hive] class HiveFunctionRegistry( underlying: analysis.FunctionRegistry, - executionHive: ClientWrapper) + executionHive: HiveClientImpl) extends analysis.FunctionRegistry with HiveInspectors { def getFunctionInfo(name: String): FunctionInfo = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index a33223af2437..246108e0d0e1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.expressions.ExpressionInfo import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.CacheTableCommand import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.client.ClientWrapper +import 
org.apache.spark.sql.hive.client.HiveClientImpl import org.apache.spark.sql.hive.execution.HiveNativeCommand import org.apache.spark.util.{ShutdownHookManager, Utils} @@ -458,7 +458,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { org.apache.spark.sql.catalyst.analysis.FunctionRegistry.builtin.copy(), this.executionHive) } -private[hive] class TestHiveFunctionRegistry(fr: SimpleFunctionRegistry, client: ClientWrapper) +private[hive] class TestHiveFunctionRegistry(fr: SimpleFunctionRegistry, client: HiveClientImpl) extends HiveFunctionRegistry(fr, client) { private val removedFunctions = diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index ff10a251f3b4..1344a2cc4bd3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.tags.ExtendedHiveTest import org.apache.spark.util.Utils /** - * A simple set of tests that call the methods of a hive ClientInterface, loading different version + * A simple set of tests that call the methods of a [[HiveClient]], loading different version * of hive from maven central. These tests are simple in that they are mostly just testing to make * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionality * is not fully tested. @@ -101,7 +101,7 @@ class VersionsSuite extends SparkFunSuite with Logging { private val versions = Seq("12", "13", "14", "1.0.0", "1.1.0", "1.2.0") - private var client: ClientInterface = null + private var client: HiveClient = null versions.foreach { version => test(s"$version: create client") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 0d62d799c8dc..1ada2e325bda 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -199,7 +199,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { "Extended Usage") checkExistence(sql("describe functioN abcadf"), true, - "Function: abcadf is not found.") + "Function: abcadf not found.") checkExistence(sql("describe functioN `~`"), true, "Function: ~", From e6ceac49a311faf3413acda57a6612fe806adf90 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 29 Jan 2016 17:59:41 -0800 Subject: [PATCH 067/131] [SPARK-13096][TEST] Fix flaky verifyPeakExecutionMemorySet Previously we would assert things before all events are guaranteed to have been processed. To fix this, just block until all events are actually processed, i.e. until the listener queue is empty. https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.7/79/testReport/junit/org.apache.spark.util.collection/ExternalAppendOnlyMapSuite/spilling/ Author: Andrew Or Closes #10990 from andrewor14/accum-suite-less-flaky. 
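The pattern behind this two-line fix, sketched as a hedged example: drain the listener bus before asserting on anything the listeners compute. The helper name and the 10-second timeout are illustrative; `listenerBus` and `waitUntilEmpty` are `private[spark]`, so this only compiles inside the `org.apache.spark` package, as in the suite being fixed.

package org.apache.spark

import org.apache.spark.scheduler.SparkListener

// Run the body, then block until every queued listener event has been
// delivered, so assertions cannot race the asynchronous listener bus.
object ListenerSyncUtils {
  def runAndAwaitListeners(sc: SparkContext, listener: SparkListener)(testBody: => Unit): Unit = {
    sc.addSparkListener(listener)
    testBody
    sc.listenerBus.waitUntilEmpty(10 * 1000) // timeout in milliseconds
  }
}

Draining the queue is preferable to sleeping for a fixed interval: it returns as soon as the bus is idle and only waits while events are actually outstanding.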
--- core/src/test/scala/org/apache/spark/AccumulatorSuite.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 11c97d7d9a44..b8f2b96d7088 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -307,6 +307,8 @@ private[spark] object AccumulatorSuite { val listener = new SaveInfoListener sc.addSparkListener(listener) testBody + // wait until all events have been processed before proceeding to assert things + sc.listenerBus.waitUntilEmpty(10 * 1000) val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values) val isSet = accums.exists { a => a.name == Some(PEAK_EXECUTION_MEMORY) && a.value.exists(_.asInstanceOf[Long] > 0L) From 70e69fc4dd619654f5d24b8b84f6a94f7705c59b Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 29 Jan 2016 18:00:49 -0800 Subject: [PATCH 068/131] [SPARK-13088] Fix DAG viz in latest version of chrome Apparently chrome removed `SVGElement.prototype.getTransformToElement`, which is used by our JS library dagre-d3 when creating edges. The real diff can be found here: https://github.com/andrewor14/dagre-d3/commit/7d6c0002e4c74b82a02c5917876576f71e215590, which is taken from the fix in the main repo: https://github.com/cpettitt/dagre-d3/commit/1ef067f1c6ad2e0980f6f0ca471bce998784b7b2 Upstream issue: https://github.com/cpettitt/dagre-d3/issues/202 Author: Andrew Or Closes #10986 from andrewor14/fix-dag-viz. --- .../org/apache/spark/ui/static/dagre-d3.min.js | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js b/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js index 2d9262b972a5..6fe8136c87ae 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js +++ b/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js @@ -1,4 +1,5 @@ -/* This is a custom version of dagre-d3 on top of v0.4.3. 
The full list of commits can be found at http://github.com/andrewor14/dagre-d3/ */!function(e){if("object"==typeof exports&&"undefined"!=typeof module)module.exports=e();else if("function"==typeof define&&define.amd)define([],e);else{var f;"undefined"!=typeof window?f=window:"undefined"!=typeof global?f=global:"undefined"!=typeof self&&(f=self),f.dagreD3=e()}}(function(){var define,module,exports;return function e(t,n,r){function s(o,u){if(!n[o]){if(!t[o]){var a=typeof require=="function"&&require;if(!u&&a)return a(o,!0);if(i)return i(o,!0);var f=new Error("Cannot find module '"+o+"'");throw f.code="MODULE_NOT_FOUND",f}var l=n[o]={exports:{}};t[o][0].call(l.exports,function(e){var n=t[o][1][e];return s(n?n:e)},l,l.exports,e,t,n,r)}return n[o].exports}var i=typeof require=="function"&&require;for(var o=0;o0}},{}],14:[function(require,module,exports){module.exports=intersectNode;function intersectNode(node,point){return node.intersect(point)}},{}],15:[function(require,module,exports){var intersectLine=require("./intersect-line");module.exports=intersectPolygon;function intersectPolygon(node,polyPoints,point){var x1=node.x;var y1=node.y;var intersections=[];var minX=Number.POSITIVE_INFINITY,minY=Number.POSITIVE_INFINITY;polyPoints.forEach(function(entry){minX=Math.min(minX,entry.x);minY=Math.min(minY,entry.y)});var left=x1-node.width/2-minX;var top=y1-node.height/2-minY;for(var i=0;i1){intersections.sort(function(p,q){var pdx=p.x-point.x,pdy=p.y-point.y,distp=Math.sqrt(pdx*pdx+pdy*pdy),qdx=q.x-point.x,qdy=q.y-point.y,distq=Math.sqrt(qdx*qdx+qdy*qdy);return distpMath.abs(dx)*h){if(dy<0){h=-h}sx=dy===0?0:h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=dx===0?0:w*dy/dx}return{x:x+sx,y:y+sy}}},{}],17:[function(require,module,exports){var util=require("../util");module.exports=addHtmlLabel;function addHtmlLabel(root,node){var fo=root.append("foreignObject").attr("width","100000");var div=fo.append("xhtml:div");var label=node.label;switch(typeof label){case"function":div.insert(label);break;case"object":div.insert(function(){return label});break;default:div.html(label)}util.applyStyle(div,node.labelStyle);div.style("display","inline-block");div.style("white-space","nowrap");var w,h;div.each(function(){w=this.clientWidth;h=this.clientHeight});fo.attr("width",w).attr("height",h);return fo}},{"../util":25}],18:[function(require,module,exports){var addTextLabel=require("./add-text-label"),addHtmlLabel=require("./add-html-label");module.exports=addLabel;function addLabel(root,node){var label=node.label;var labelSvg=root.append("g");if(typeof label!=="string"||node.labelType==="html"){addHtmlLabel(labelSvg,node)}else{addTextLabel(labelSvg,node)}var labelBBox=labelSvg.node().getBBox();labelSvg.attr("transform","translate("+-labelBBox.width/2+","+-labelBBox.height/2+")");return labelSvg}},{"./add-html-label":17,"./add-text-label":19}],19:[function(require,module,exports){var util=require("../util");module.exports=addTextLabel;function addTextLabel(root,node){var domNode=root.append("text");var lines=processEscapeSequences(node.label).split("\n");for(var i=0;imaxPadding){maxPadding=child.paddingTop}}return maxPadding}function getRank(g,v){var maxRank=0;var children=g.children(v);for(var i=0;imaxRank){maxRank=thisRank}}return maxRank}function orderByRank(g,nodes){return nodes.sort(function(x,y){return getRank(g,x)-getRank(g,y)})}function edgeToId(e){return escapeId(e.v)+":"+escapeId(e.w)+":"+escapeId(e.name)}var ID_DELIM=/:/g;function escapeId(str){return str?String(str).replace(ID_DELIM,"\\:"):""}function 
applyStyle(dom,styleFn){if(styleFn){dom.attr("style",styleFn)}}function applyClass(dom,classFn,otherClasses){if(classFn){dom.attr("class",classFn).attr("class",otherClasses+" "+dom.attr("class"))}}function applyTransition(selection,g){var graph=g.graph();if(_.isPlainObject(graph)){var transition=graph.transition;if(_.isFunction(transition)){return transition(selection)}}return selection}},{"./lodash":20}],26:[function(require,module,exports){module.exports="0.4.4-pre"},{}],27:[function(require,module,exports){module.exports={graphlib:require("./lib/graphlib"),layout:require("./lib/layout"),debug:require("./lib/debug"),util:{time:require("./lib/util").time,notime:require("./lib/util").notime},version:require("./lib/version")}},{"./lib/debug":32,"./lib/graphlib":33,"./lib/layout":35,"./lib/util":55,"./lib/version":56}],28:[function(require,module,exports){"use strict";var _=require("./lodash"),greedyFAS=require("./greedy-fas");module.exports={run:run,undo:undo};function run(g){var fas=g.graph().acyclicer==="greedy"?greedyFAS(g,weightFn(g)):dfsFAS(g);_.each(fas,function(e){var label=g.edge(e);g.removeEdge(e);label.forwardName=e.name;label.reversed=true;g.setEdge(e.w,e.v,label,_.uniqueId("rev"))});function weightFn(g){return function(e){return g.edge(e).weight}}}function dfsFAS(g){var fas=[],stack={},visited={};function dfs(v){if(_.has(visited,v)){return}visited[v]=true;stack[v]=true;_.each(g.outEdges(v),function(e){if(_.has(stack,e.w)){fas.push(e)}else{dfs(e.w)}});delete stack[v]}_.each(g.nodes(),dfs);return fas}function undo(g){_.each(g.edges(),function(e){var label=g.edge(e);if(label.reversed){g.removeEdge(e);var forwardName=label.forwardName;delete label.reversed;delete label.forwardName;g.setEdge(e.w,e.v,label,forwardName)}})}},{"./greedy-fas":34,"./lodash":36}],29:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports=addBorderSegments;function addBorderSegments(g){function dfs(v){var children=g.children(v),node=g.node(v);if(children.length){_.each(children,dfs)}if(_.has(node,"minRank")){node.borderLeft=[];node.borderRight=[];for(var rank=node.minRank,maxRank=node.maxRank+1;rank0;--i){entry=buckets[i].dequeue();if(entry){results=results.concat(removeNode(g,buckets,zeroIdx,entry,true));break}}}}return results}function removeNode(g,buckets,zeroIdx,entry,collectPredecessors){var results=collectPredecessors?[]:undefined;_.each(g.inEdges(entry.v),function(edge){var weight=g.edge(edge),uEntry=g.node(edge.v);if(collectPredecessors){results.push({v:edge.v,w:edge.w})}uEntry.out-=weight;assignBucket(buckets,zeroIdx,uEntry)});_.each(g.outEdges(entry.v),function(edge){var weight=g.edge(edge),w=edge.w,wEntry=g.node(w);wEntry["in"]-=weight;assignBucket(buckets,zeroIdx,wEntry)});g.removeNode(entry.v);return results}function buildState(g,weightFn){var fasGraph=new Graph,maxIn=0,maxOut=0;_.each(g.nodes(),function(v){fasGraph.setNode(v,{v:v,"in":0,out:0})});_.each(g.edges(),function(e){var prevWeight=fasGraph.edge(e.v,e.w)||0,weight=weightFn(e),edgeWeight=prevWeight+weight;fasGraph.setEdge(e.v,e.w,edgeWeight);maxOut=Math.max(maxOut,fasGraph.node(e.v).out+=weight);maxIn=Math.max(maxIn,fasGraph.node(e.w)["in"]+=weight)});var buckets=_.range(maxOut+maxIn+3).map(function(){return new List});var zeroIdx=maxIn+1;_.each(fasGraph.nodes(),function(v){assignBucket(buckets,zeroIdx,fasGraph.node(v))});return{graph:fasGraph,buckets:buckets,zeroIdx:zeroIdx}}function assignBucket(buckets,zeroIdx,entry){if(!entry.out){buckets[0].enqueue(entry)}else 
if(!entry["in"]){buckets[buckets.length-1].enqueue(entry)}else{buckets[entry.out-entry["in"]+zeroIdx].enqueue(entry)}}},{"./data/list":31,"./graphlib":33,"./lodash":36}],35:[function(require,module,exports){"use strict";var _=require("./lodash"),acyclic=require("./acyclic"),normalize=require("./normalize"),rank=require("./rank"),normalizeRanks=require("./util").normalizeRanks,parentDummyChains=require("./parent-dummy-chains"),removeEmptyRanks=require("./util").removeEmptyRanks,nestingGraph=require("./nesting-graph"),addBorderSegments=require("./add-border-segments"),coordinateSystem=require("./coordinate-system"),order=require("./order"),position=require("./position"),util=require("./util"),Graph=require("./graphlib").Graph;module.exports=layout;function layout(g,opts){var time=opts&&opts.debugTiming?util.time:util.notime;time("layout",function(){var layoutGraph=time(" buildLayoutGraph",function(){return buildLayoutGraph(g)});time(" runLayout",function(){runLayout(layoutGraph,time)});time(" updateInputGraph",function(){updateInputGraph(g,layoutGraph)})})}function runLayout(g,time){time(" makeSpaceForEdgeLabels",function(){makeSpaceForEdgeLabels(g)});time(" removeSelfEdges",function(){removeSelfEdges(g)});time(" acyclic",function(){acyclic.run(g)});time(" nestingGraph.run",function(){nestingGraph.run(g)});time(" rank",function(){rank(util.asNonCompoundGraph(g))});time(" injectEdgeLabelProxies",function(){injectEdgeLabelProxies(g)});time(" removeEmptyRanks",function(){removeEmptyRanks(g)});time(" nestingGraph.cleanup",function(){nestingGraph.cleanup(g)});time(" normalizeRanks",function(){normalizeRanks(g)});time(" assignRankMinMax",function(){assignRankMinMax(g)});time(" removeEdgeLabelProxies",function(){removeEdgeLabelProxies(g)});time(" normalize.run",function(){normalize.run(g)});time(" parentDummyChains",function(){ -parentDummyChains(g)});time(" addBorderSegments",function(){addBorderSegments(g)});time(" order",function(){order(g)});time(" insertSelfEdges",function(){insertSelfEdges(g)});time(" adjustCoordinateSystem",function(){coordinateSystem.adjust(g)});time(" position",function(){position(g)});time(" positionSelfEdges",function(){positionSelfEdges(g)});time(" removeBorderNodes",function(){removeBorderNodes(g)});time(" normalize.undo",function(){normalize.undo(g)});time(" fixupEdgeLabelCoords",function(){fixupEdgeLabelCoords(g)});time(" undoCoordinateSystem",function(){coordinateSystem.undo(g)});time(" translateGraph",function(){translateGraph(g)});time(" assignNodeIntersects",function(){assignNodeIntersects(g)});time(" reversePoints",function(){reversePointsForReversedEdges(g)});time(" acyclic.undo",function(){acyclic.undo(g)})}function updateInputGraph(inputGraph,layoutGraph){_.each(inputGraph.nodes(),function(v){var inputLabel=inputGraph.node(v),layoutLabel=layoutGraph.node(v);if(inputLabel){inputLabel.x=layoutLabel.x;inputLabel.y=layoutLabel.y;if(layoutGraph.children(v).length){inputLabel.width=layoutLabel.width;inputLabel.height=layoutLabel.height}}});_.each(inputGraph.edges(),function(e){var inputLabel=inputGraph.edge(e),layoutLabel=layoutGraph.edge(e);inputLabel.points=layoutLabel.points;if(_.has(layoutLabel,"x")){inputLabel.x=layoutLabel.x;inputLabel.y=layoutLabel.y}});inputGraph.graph().width=layoutGraph.graph().width;inputGraph.graph().height=layoutGraph.graph().height}var 
graphNumAttrs=["nodesep","edgesep","ranksep","marginx","marginy"],graphDefaults={ranksep:50,edgesep:20,nodesep:50,rankdir:"tb"},graphAttrs=["acyclicer","ranker","rankdir","align"],nodeNumAttrs=["width","height"],nodeDefaults={width:0,height:0},edgeNumAttrs=["minlen","weight","width","height","labeloffset"],edgeDefaults={minlen:1,weight:1,width:0,height:0,labeloffset:10,labelpos:"r"},edgeAttrs=["labelpos"];function buildLayoutGraph(inputGraph){var g=new Graph({multigraph:true,compound:true}),graph=canonicalize(inputGraph.graph());g.setGraph(_.merge({},graphDefaults,selectNumberAttrs(graph,graphNumAttrs),_.pick(graph,graphAttrs)));_.each(inputGraph.nodes(),function(v){var node=canonicalize(inputGraph.node(v));g.setNode(v,_.defaults(selectNumberAttrs(node,nodeNumAttrs),nodeDefaults));g.setParent(v,inputGraph.parent(v))});_.each(inputGraph.edges(),function(e){var edge=canonicalize(inputGraph.edge(e));g.setEdge(e,_.merge({},edgeDefaults,selectNumberAttrs(edge,edgeNumAttrs),_.pick(edge,edgeAttrs)))});return g}function makeSpaceForEdgeLabels(g){var graph=g.graph();graph.ranksep/=2;_.each(g.edges(),function(e){var edge=g.edge(e);edge.minlen*=2;if(edge.labelpos.toLowerCase()!=="c"){if(graph.rankdir==="TB"||graph.rankdir==="BT"){edge.width+=edge.labeloffset}else{edge.height+=edge.labeloffset}}})}function injectEdgeLabelProxies(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.width&&edge.height){var v=g.node(e.v),w=g.node(e.w),label={rank:(w.rank-v.rank)/2+v.rank,e:e};util.addDummyNode(g,"edge-proxy",label,"_ep")}})}function assignRankMinMax(g){var maxRank=0;_.each(g.nodes(),function(v){var node=g.node(v);if(node.borderTop){node.minRank=g.node(node.borderTop).rank;node.maxRank=g.node(node.borderBottom).rank;maxRank=_.max(maxRank,node.maxRank)}});g.graph().maxRank=maxRank}function removeEdgeLabelProxies(g){_.each(g.nodes(),function(v){var node=g.node(v);if(node.dummy==="edge-proxy"){g.edge(node.e).labelRank=node.rank;g.removeNode(v)}})}function translateGraph(g){var minX=Number.POSITIVE_INFINITY,maxX=0,minY=Number.POSITIVE_INFINITY,maxY=0,graphLabel=g.graph(),marginX=graphLabel.marginx||0,marginY=graphLabel.marginy||0;function getExtremes(attrs){var x=attrs.x,y=attrs.y,w=attrs.width,h=attrs.height;minX=Math.min(minX,x-w/2);maxX=Math.max(maxX,x+w/2);minY=Math.min(minY,y-h/2);maxY=Math.max(maxY,y+h/2)}_.each(g.nodes(),function(v){getExtremes(g.node(v))});_.each(g.edges(),function(e){var edge=g.edge(e);if(_.has(edge,"x")){getExtremes(edge)}});minX-=marginX;minY-=marginY;_.each(g.nodes(),function(v){var node=g.node(v);node.x-=minX;node.y-=minY});_.each(g.edges(),function(e){var edge=g.edge(e);_.each(edge.points,function(p){p.x-=minX;p.y-=minY});if(_.has(edge,"x")){edge.x-=minX}if(_.has(edge,"y")){edge.y-=minY}});graphLabel.width=maxX-minX+marginX;graphLabel.height=maxY-minY+marginY}function assignNodeIntersects(g){_.each(g.edges(),function(e){var edge=g.edge(e),nodeV=g.node(e.v),nodeW=g.node(e.w),p1,p2;if(!edge.points){edge.points=[];p1=nodeW;p2=nodeV}else{p1=edge.points[0];p2=edge.points[edge.points.length-1]}edge.points.unshift(util.intersectRect(nodeV,p1));edge.points.push(util.intersectRect(nodeW,p2))})}function fixupEdgeLabelCoords(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(_.has(edge,"x")){if(edge.labelpos==="l"||edge.labelpos==="r"){edge.width-=edge.labeloffset}switch(edge.labelpos){case"l":edge.x-=edge.width/2+edge.labeloffset;break;case"r":edge.x+=edge.width/2+edge.labeloffset;break}}})}function reversePointsForReversedEdges(g){_.each(g.edges(),function(e){var 
edge=g.edge(e);if(edge.reversed){edge.points.reverse()}})}function removeBorderNodes(g){_.each(g.nodes(),function(v){if(g.children(v).length){var node=g.node(v),t=g.node(node.borderTop),b=g.node(node.borderBottom),l=g.node(_.last(node.borderLeft)),r=g.node(_.last(node.borderRight));node.width=Math.abs(r.x-l.x);node.height=Math.abs(b.y-t.y);node.x=l.x+node.width/2;node.y=t.y+node.height/2}});_.each(g.nodes(),function(v){if(g.node(v).dummy==="border"){g.removeNode(v)}})}function removeSelfEdges(g){_.each(g.edges(),function(e){if(e.v===e.w){var node=g.node(e.v);if(!node.selfEdges){node.selfEdges=[]}node.selfEdges.push({e:e,label:g.edge(e)});g.removeEdge(e)}})}function insertSelfEdges(g){var layers=util.buildLayerMatrix(g);_.each(layers,function(layer){var orderShift=0;_.each(layer,function(v,i){var node=g.node(v);node.order=i+orderShift;_.each(node.selfEdges,function(selfEdge){util.addDummyNode(g,"selfedge",{width:selfEdge.label.width,height:selfEdge.label.height,rank:node.rank,order:i+ ++orderShift,e:selfEdge.e,label:selfEdge.label},"_se")});delete node.selfEdges})})}function positionSelfEdges(g){_.each(g.nodes(),function(v){var node=g.node(v);if(node.dummy==="selfedge"){var selfNode=g.node(node.e.v),x=selfNode.x+selfNode.width/2,y=selfNode.y,dx=node.x-x,dy=selfNode.height/2;g.setEdge(node.e,node.label);g.removeNode(v);node.label.points=[{x:x+2*dx/3,y:y-dy},{x:x+5*dx/6,y:y-dy},{x:x+dx,y:y},{x:x+5*dx/6,y:y+dy},{x:x+2*dx/3,y:y+dy}];node.label.x=node.x;node.label.y=node.y}})}function selectNumberAttrs(obj,attrs){return _.mapValues(_.pick(obj,attrs),Number)}function canonicalize(attrs){var newAttrs={};_.each(attrs,function(v,k){newAttrs[k.toLowerCase()]=v});return newAttrs}},{"./acyclic":28,"./add-border-segments":29,"./coordinate-system":30,"./graphlib":33,"./lodash":36,"./nesting-graph":37,"./normalize":38,"./order":43,"./parent-dummy-chains":48,"./position":50,"./rank":52,"./util":55}],36:[function(require,module,exports){module.exports=require(20)},{"/Users/andrew/Documents/dev/dagre-d3/lib/lodash.js":20,lodash:77}],37:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports={run:run,cleanup:cleanup};function run(g){var root=util.addDummyNode(g,"root",{},"_root"),depths=treeDepths(g),height=_.max(depths)-1,nodeSep=2*height+1;g.graph().nestingRoot=root;_.each(g.edges(),function(e){g.edge(e).minlen*=nodeSep});var weight=sumWeights(g)+1;_.each(g.children(),function(child){dfs(g,root,nodeSep,weight,height,depths,child)});g.graph().nodeRankFactor=nodeSep}function dfs(g,root,nodeSep,weight,height,depths,v){var children=g.children(v);if(!children.length){if(v!==root){g.setEdge(root,v,{weight:0,minlen:nodeSep})}return}var top=util.addBorderNode(g,"_bt"),bottom=util.addBorderNode(g,"_bb"),label=g.node(v);g.setParent(top,v);label.borderTop=top;g.setParent(bottom,v);label.borderBottom=bottom;_.each(children,function(child){dfs(g,root,nodeSep,weight,height,depths,child);var childNode=g.node(child),childTop=childNode.borderTop?childNode.borderTop:child,childBottom=childNode.borderBottom?childNode.borderBottom:child,thisWeight=childNode.borderTop?weight:2*weight,minlen=childTop!==childBottom?1:height-depths[v]+1;g.setEdge(top,childTop,{weight:thisWeight,minlen:minlen,nestingEdge:true});g.setEdge(childBottom,bottom,{weight:thisWeight,minlen:minlen,nestingEdge:true})});if(!g.parent(v)){g.setEdge(root,top,{weight:0,minlen:height+depths[v]})}}function treeDepths(g){var depths={};function dfs(v,depth){var 
children=g.children(v);if(children&&children.length){_.each(children,function(child){dfs(child,depth+1)})}depths[v]=depth}_.each(g.children(),function(v){dfs(v,1)});return depths}function sumWeights(g){return _.reduce(g.edges(),function(acc,e){return acc+g.edge(e).weight},0)}function cleanup(g){var graphLabel=g.graph();g.removeNode(graphLabel.nestingRoot);delete graphLabel.nestingRoot;_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.nestingEdge){g.removeEdge(e)}})}},{"./lodash":36,"./util":55}],38:[function(require,module,exports){"use strict";var _=require("./lodash"),util=require("./util");module.exports={run:run,undo:undo};function run(g){g.graph().dummyChains=[];_.each(g.edges(),function(edge){normalizeEdge(g,edge)})}function normalizeEdge(g,e){var v=e.v,vRank=g.node(v).rank,w=e.w,wRank=g.node(w).rank,name=e.name,edgeLabel=g.edge(e),labelRank=edgeLabel.labelRank;if(wRank===vRank+1)return;g.removeEdge(e);var dummy,attrs,i;for(i=0,++vRank;vRank0){if(index%2){weightSum+=tree[index+1]}index=index-1>>1;tree[index]+=entry.weight}cc+=entry.weight*weightSum}));return cc}},{"../lodash":36}],43:[function(require,module,exports){"use strict";var _=require("../lodash"),initOrder=require("./init-order"),crossCount=require("./cross-count"),sortSubgraph=require("./sort-subgraph"),buildLayerGraph=require("./build-layer-graph"),addSubgraphConstraints=require("./add-subgraph-constraints"),Graph=require("../graphlib").Graph,util=require("../util");module.exports=order;function order(g){var maxRank=util.maxRank(g),downLayerGraphs=buildLayerGraphs(g,_.range(1,maxRank+1),"inEdges"),upLayerGraphs=buildLayerGraphs(g,_.range(maxRank-1,-1,-1),"outEdges");var layering=initOrder(g);assignOrder(g,layering);var bestCC=Number.POSITIVE_INFINITY,best;for(var i=0,lastBest=0;lastBest<4;++i,++lastBest){sweepLayerGraphs(i%2?downLayerGraphs:upLayerGraphs,i%4>=2);layering=util.buildLayerMatrix(g);var cc=crossCount(g,layering);if(cc=vEntry.barycenter){mergeEntries(vEntry,uEntry)}}}function handleOut(vEntry){return function(wEntry){wEntry["in"].push(vEntry);if(--wEntry.indegree===0){sourceSet.push(wEntry)}}}while(sourceSet.length){var entry=sourceSet.pop();entries.push(entry);_.each(entry["in"].reverse(),handleIn(entry));_.each(entry.out,handleOut(entry))}return _.chain(entries).filter(function(entry){return!entry.merged}).map(function(entry){return _.pick(entry,["vs","i","barycenter","weight"])}).value()}function mergeEntries(target,source){var sum=0,weight=0;if(target.weight){sum+=target.barycenter*target.weight;weight+=target.weight}if(source.weight){sum+=source.barycenter*source.weight;weight+=source.weight}target.vs=source.vs.concat(target.vs);target.barycenter=sum/weight;target.weight=weight;target.i=Math.min(source.i,target.i);source.merged=true}},{"../lodash":36}],46:[function(require,module,exports){var _=require("../lodash"),barycenter=require("./barycenter"),resolveConflicts=require("./resolve-conflicts"),sort=require("./sort");module.exports=sortSubgraph;function sortSubgraph(g,v,cg,biasRight){var movable=g.children(v),node=g.node(v),bl=node?node.borderLeft:undefined,br=node?node.borderRight:undefined,subgraphs={};if(bl){movable=_.filter(movable,function(w){return w!==bl&&w!==br})}var barycenters=barycenter(g,movable);_.each(barycenters,function(entry){if(g.children(entry.v).length){var subgraphResult=sortSubgraph(g,entry.v,cg,biasRight);subgraphs[entry.v]=subgraphResult;if(_.has(subgraphResult,"barycenter")){mergeBarycenters(entry,subgraphResult)}}});var 
entries=resolveConflicts(barycenters,cg);expandSubgraphs(entries,subgraphs);var result=sort(entries,biasRight);if(bl){result.vs=_.flatten([bl,result.vs,br],true);if(g.predecessors(bl).length){var blPred=g.node(g.predecessors(bl)[0]),brPred=g.node(g.predecessors(br)[0]);if(!_.has(result,"barycenter")){result.barycenter=0;result.weight=0}result.barycenter=(result.barycenter*result.weight+blPred.order+brPred.order)/(result.weight+2);result.weight+=2}}return result}function expandSubgraphs(entries,subgraphs){_.each(entries,function(entry){entry.vs=_.flatten(entry.vs.map(function(v){if(subgraphs[v]){return subgraphs[v].vs}return v}),true)})}function mergeBarycenters(target,other){if(!_.isUndefined(target.barycenter)){target.barycenter=(target.barycenter*target.weight+other.barycenter*other.weight)/(target.weight+other.weight);target.weight+=other.weight}else{target.barycenter=other.barycenter;target.weight=other.weight}}},{"../lodash":36,"./barycenter":40,"./resolve-conflicts":45,"./sort":47}],47:[function(require,module,exports){var _=require("../lodash"),util=require("../util");module.exports=sort;function sort(entries,biasRight){var parts=util.partition(entries,function(entry){return _.has(entry,"barycenter")});var sortable=parts.lhs,unsortable=_.sortBy(parts.rhs,function(entry){return-entry.i}),vs=[],sum=0,weight=0,vsIndex=0;sortable.sort(compareWithBias(!!biasRight));vsIndex=consumeUnsortable(vs,unsortable,vsIndex);_.each(sortable,function(entry){vsIndex+=entry.vs.length;vs.push(entry.vs);sum+=entry.barycenter*entry.weight;weight+=entry.weight;vsIndex=consumeUnsortable(vs,unsortable,vsIndex)});var result={vs:_.flatten(vs,true)};if(weight){result.barycenter=sum/weight;result.weight=weight}return result}function consumeUnsortable(vs,unsortable,index){var last;while(unsortable.length&&(last=_.last(unsortable)).i<=index){unsortable.pop();vs.push(last.vs);index++}return index}function compareWithBias(bias){return function(entryV,entryW){if(entryV.barycenterentryW.barycenter){return 1}return!bias?entryV.i-entryW.i:entryW.i-entryV.i}}},{"../lodash":36,"../util":55}],48:[function(require,module,exports){var _=require("./lodash");module.exports=parentDummyChains;function parentDummyChains(g){var postorderNums=postorder(g);_.each(g.graph().dummyChains,function(v){var node=g.node(v),edgeObj=node.edgeObj,pathData=findPath(g,postorderNums,edgeObj.v,edgeObj.w),path=pathData.path,lca=pathData.lca,pathIdx=0,pathV=path[pathIdx],ascending=true;while(v!==edgeObj.w){node=g.node(v);if(ascending){while((pathV=path[pathIdx])!==lca&&g.node(pathV).maxRanklow||lim>postorderNums[parent].lim));lca=parent;parent=w;while((parent=g.parent(parent))!==lca){wPath.push(parent)}return{path:vPath.concat(wPath.reverse()),lca:lca}}function postorder(g){var result={},lim=0;function dfs(v){var low=lim;_.each(g.children(v),dfs);result[v]={low:low,lim:lim++}}_.each(g.children(),dfs);return result}},{"./lodash":36}],49:[function(require,module,exports){"use strict";var _=require("../lodash"),Graph=require("../graphlib").Graph,util=require("../util");module.exports={positionX:positionX,findType1Conflicts:findType1Conflicts,findType2Conflicts:findType2Conflicts,addConflict:addConflict,hasConflict:hasConflict,verticalAlignment:verticalAlignment,horizontalCompaction:horizontalCompaction,alignCoordinates:alignCoordinates,findSmallestWidthAlignment:findSmallestWidthAlignment,balance:balance};function findType1Conflicts(g,layering){var conflicts={};function visitLayer(prevLayer,layer){var 
k0=0,scanPos=0,prevLayerLength=prevLayer.length,lastNode=_.last(layer);_.each(layer,function(v,i){var w=findOtherInnerSegmentNode(g,v),k1=w?g.node(w).order:prevLayerLength;if(w||v===lastNode){_.each(layer.slice(scanPos,i+1),function(scanNode){_.each(g.predecessors(scanNode),function(u){var uLabel=g.node(u),uPos=uLabel.order;if((uPosnextNorthBorder)){addConflict(conflicts,u,v)}})}})}function visitLayer(north,south){var prevNorthPos=-1,nextNorthPos,southPos=0;_.each(south,function(v,southLookahead){if(g.node(v).dummy==="border"){var predecessors=g.predecessors(v);if(predecessors.length){nextNorthPos=g.node(predecessors[0]).order;scan(south,southPos,southLookahead,prevNorthPos,nextNorthPos);southPos=southLookahead;prevNorthPos=nextNorthPos}}scan(south,southPos,south.length,nextNorthPos,north.length)});return south}_.reduce(layering,visitLayer);return conflicts}function findOtherInnerSegmentNode(g,v){if(g.node(v).dummy){return _.find(g.predecessors(v),function(u){return g.node(u).dummy})}}function addConflict(conflicts,v,w){if(v>w){var tmp=v;v=w;w=tmp}var conflictsV=conflicts[v];if(!conflictsV){conflicts[v]=conflictsV={}}conflictsV[w]=true}function hasConflict(conflicts,v,w){if(v>w){var tmp=v;v=w;w=tmp}return _.has(conflicts[v],w)}function verticalAlignment(g,layering,conflicts,neighborFn){var root={},align={},pos={};_.each(layering,function(layer){_.each(layer,function(v,order){root[v]=v;align[v]=v;pos[v]=order})});_.each(layering,function(layer){var prevIdx=-1;_.each(layer,function(v){var ws=neighborFn(v);if(ws.length){ws=_.sortBy(ws,function(w){return pos[w]});var mp=(ws.length-1)/2;for(var i=Math.floor(mp),il=Math.ceil(mp);i<=il;++i){var w=ws[i];if(align[v]===v&&prevIdx0}},{}],14:[function(require,module,exports){module.exports=intersectNode;function intersectNode(node,point){return node.intersect(point)}},{}],15:[function(require,module,exports){var intersectLine=require("./intersect-line");module.exports=intersectPolygon;function intersectPolygon(node,polyPoints,point){var x1=node.x;var y1=node.y;var intersections=[];var minX=Number.POSITIVE_INFINITY,minY=Number.POSITIVE_INFINITY;polyPoints.forEach(function(entry){minX=Math.min(minX,entry.x);minY=Math.min(minY,entry.y)});var left=x1-node.width/2-minX;var top=y1-node.height/2-minY;for(var i=0;i1){intersections.sort(function(p,q){var pdx=p.x-point.x,pdy=p.y-point.y,distp=Math.sqrt(pdx*pdx+pdy*pdy),qdx=q.x-point.x,qdy=q.y-point.y,distq=Math.sqrt(qdx*qdx+qdy*qdy);return distpMath.abs(dx)*h){if(dy<0){h=-h}sx=dy===0?0:h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=dx===0?0:w*dy/dx}return{x:x+sx,y:y+sy}}},{}],17:[function(require,module,exports){var util=require("../util");module.exports=addHtmlLabel;function addHtmlLabel(root,node){var fo=root.append("foreignObject").attr("width","100000");var div=fo.append("xhtml:div");var label=node.label;switch(typeof label){case"function":div.insert(label);break;case"object":div.insert(function(){return label});break;default:div.html(label)}util.applyStyle(div,node.labelStyle);div.style("display","inline-block");div.style("white-space","nowrap");var w,h;div.each(function(){w=this.clientWidth;h=this.clientHeight});fo.attr("width",w).attr("height",h);return fo}},{"../util":25}],18:[function(require,module,exports){var addTextLabel=require("./add-text-label"),addHtmlLabel=require("./add-html-label");module.exports=addLabel;function addLabel(root,node){var label=node.label;var labelSvg=root.append("g");if(typeof 
label!=="string"||node.labelType==="html"){addHtmlLabel(labelSvg,node)}else{addTextLabel(labelSvg,node)}var labelBBox=labelSvg.node().getBBox();labelSvg.attr("transform","translate("+-labelBBox.width/2+","+-labelBBox.height/2+")");return labelSvg}},{"./add-html-label":17,"./add-text-label":19}],19:[function(require,module,exports){var util=require("../util");module.exports=addTextLabel;function addTextLabel(root,node){var domNode=root.append("text");var lines=processEscapeSequences(node.label).split("\n");for(var i=0;imaxPadding){maxPadding=child.paddingTop}}return maxPadding}function getRank(g,v){var maxRank=0;var children=g.children(v);for(var i=0;imaxRank){maxRank=thisRank}}return maxRank}function orderByRank(g,nodes){return nodes.sort(function(x,y){return getRank(g,x)-getRank(g,y)})}function edgeToId(e){return escapeId(e.v)+":"+escapeId(e.w)+":"+escapeId(e.name)}var ID_DELIM=/:/g;function escapeId(str){return str?String(str).replace(ID_DELIM,"\\:"):""}function applyStyle(dom,styleFn){if(styleFn){dom.attr("style",styleFn)}}function applyClass(dom,classFn,otherClasses){if(classFn){dom.attr("class",classFn).attr("class",otherClasses+" "+dom.attr("class"))}}function applyTransition(selection,g){var graph=g.graph();if(_.isPlainObject(graph)){var transition=graph.transition;if(_.isFunction(transition)){return transition(selection)}}return selection}},{"./lodash":20}],26:[function(require,module,exports){module.exports="0.4.4-pre"},{}],27:[function(require,module,exports){module.exports={graphlib:require("./lib/graphlib"),layout:require("./lib/layout"),debug:require("./lib/debug"),util:{time:require("./lib/util").time,notime:require("./lib/util").notime},version:require("./lib/version")}},{"./lib/debug":32,"./lib/graphlib":33,"./lib/layout":35,"./lib/util":55,"./lib/version":56}],28:[function(require,module,exports){"use strict";var _=require("./lodash"),greedyFAS=require("./greedy-fas");module.exports={run:run,undo:undo};function run(g){var fas=g.graph().acyclicer==="greedy"?greedyFAS(g,weightFn(g)):dfsFAS(g);_.each(fas,function(e){var label=g.edge(e);g.removeEdge(e);label.forwardName=e.name;label.reversed=true;g.setEdge(e.w,e.v,label,_.uniqueId("rev"))});function weightFn(g){return function(e){return g.edge(e).weight}}}function dfsFAS(g){var fas=[],stack={},visited={};function dfs(v){if(_.has(visited,v)){return}visited[v]=true;stack[v]=true;_.each(g.outEdges(v),function(e){if(_.has(stack,e.w)){fas.push(e)}else{dfs(e.w)}});delete stack[v]}_.each(g.nodes(),dfs);return fas}function undo(g){_.each(g.edges(),function(e){var label=g.edge(e);if(label.reversed){g.removeEdge(e);var forwardName=label.forwardName;delete label.reversed;delete label.forwardName;g.setEdge(e.w,e.v,label,forwardName)}})}},{"./greedy-fas":34,"./lodash":36}],29:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports=addBorderSegments;function addBorderSegments(g){function dfs(v){var children=g.children(v),node=g.node(v);if(children.length){_.each(children,dfs)}if(_.has(node,"minRank")){node.borderLeft=[];node.borderRight=[];for(var rank=node.minRank,maxRank=node.maxRank+1;rank0;--i){entry=buckets[i].dequeue();if(entry){results=results.concat(removeNode(g,buckets,zeroIdx,entry,true));break}}}}return results}function removeNode(g,buckets,zeroIdx,entry,collectPredecessors){var results=collectPredecessors?[]:undefined;_.each(g.inEdges(entry.v),function(edge){var 
weight=g.edge(edge),uEntry=g.node(edge.v);if(collectPredecessors){results.push({v:edge.v,w:edge.w})}uEntry.out-=weight;assignBucket(buckets,zeroIdx,uEntry)});_.each(g.outEdges(entry.v),function(edge){var weight=g.edge(edge),w=edge.w,wEntry=g.node(w);wEntry["in"]-=weight;assignBucket(buckets,zeroIdx,wEntry)});g.removeNode(entry.v);return results}function buildState(g,weightFn){var fasGraph=new Graph,maxIn=0,maxOut=0;_.each(g.nodes(),function(v){fasGraph.setNode(v,{v:v,"in":0,out:0})});_.each(g.edges(),function(e){var prevWeight=fasGraph.edge(e.v,e.w)||0,weight=weightFn(e),edgeWeight=prevWeight+weight;fasGraph.setEdge(e.v,e.w,edgeWeight);maxOut=Math.max(maxOut,fasGraph.node(e.v).out+=weight);maxIn=Math.max(maxIn,fasGraph.node(e.w)["in"]+=weight)});var buckets=_.range(maxOut+maxIn+3).map(function(){return new List});var zeroIdx=maxIn+1;_.each(fasGraph.nodes(),function(v){assignBucket(buckets,zeroIdx,fasGraph.node(v))});return{graph:fasGraph,buckets:buckets,zeroIdx:zeroIdx}}function assignBucket(buckets,zeroIdx,entry){if(!entry.out){buckets[0].enqueue(entry)}else if(!entry["in"]){buckets[buckets.length-1].enqueue(entry)}else{buckets[entry.out-entry["in"]+zeroIdx].enqueue(entry)}}},{"./data/list":31,"./graphlib":33,"./lodash":36}],35:[function(require,module,exports){"use strict";var _=require("./lodash"),acyclic=require("./acyclic"),normalize=require("./normalize"),rank=require("./rank"),normalizeRanks=require("./util").normalizeRanks,parentDummyChains=require("./parent-dummy-chains"),removeEmptyRanks=require("./util").removeEmptyRanks,nestingGraph=require("./nesting-graph"),addBorderSegments=require("./add-border-segments"),coordinateSystem=require("./coordinate-system"),order=require("./order"),position=require("./position"),util=require("./util"),Graph=require("./graphlib").Graph;module.exports=layout;function layout(g,opts){var time=opts&&opts.debugTiming?util.time:util.notime;time("layout",function(){var layoutGraph=time(" buildLayoutGraph",function(){return buildLayoutGraph(g)});time(" runLayout",function(){runLayout(layoutGraph,time)});time(" updateInputGraph",function(){updateInputGraph(g,layoutGraph)})})}function runLayout(g,time){time(" makeSpaceForEdgeLabels",function(){makeSpaceForEdgeLabels(g)});time(" removeSelfEdges",function(){removeSelfEdges(g)});time(" acyclic",function(){acyclic.run(g)});time(" nestingGraph.run",function(){nestingGraph.run(g)});time(" rank",function(){rank(util.asNonCompoundGraph(g))});time(" injectEdgeLabelProxies",function(){injectEdgeLabelProxies(g)});time(" removeEmptyRanks",function(){removeEmptyRanks(g)});time(" nestingGraph.cleanup",function(){nestingGraph.cleanup(g)});time(" normalizeRanks",function(){normalizeRanks(g)});time(" assignRankMinMax",function(){assignRankMinMax(g)});time(" removeEdgeLabelProxies",function(){removeEdgeLabelProxies(g)});time(" normalize.run",function(){ +normalize.run(g)});time(" parentDummyChains",function(){parentDummyChains(g)});time(" addBorderSegments",function(){addBorderSegments(g)});time(" order",function(){order(g)});time(" insertSelfEdges",function(){insertSelfEdges(g)});time(" adjustCoordinateSystem",function(){coordinateSystem.adjust(g)});time(" position",function(){position(g)});time(" positionSelfEdges",function(){positionSelfEdges(g)});time(" removeBorderNodes",function(){removeBorderNodes(g)});time(" normalize.undo",function(){normalize.undo(g)});time(" fixupEdgeLabelCoords",function(){fixupEdgeLabelCoords(g)});time(" undoCoordinateSystem",function(){coordinateSystem.undo(g)});time(" 
translateGraph",function(){translateGraph(g)});time(" assignNodeIntersects",function(){assignNodeIntersects(g)});time(" reversePoints",function(){reversePointsForReversedEdges(g)});time(" acyclic.undo",function(){acyclic.undo(g)})}function updateInputGraph(inputGraph,layoutGraph){_.each(inputGraph.nodes(),function(v){var inputLabel=inputGraph.node(v),layoutLabel=layoutGraph.node(v);if(inputLabel){inputLabel.x=layoutLabel.x;inputLabel.y=layoutLabel.y;if(layoutGraph.children(v).length){inputLabel.width=layoutLabel.width;inputLabel.height=layoutLabel.height}}});_.each(inputGraph.edges(),function(e){var inputLabel=inputGraph.edge(e),layoutLabel=layoutGraph.edge(e);inputLabel.points=layoutLabel.points;if(_.has(layoutLabel,"x")){inputLabel.x=layoutLabel.x;inputLabel.y=layoutLabel.y}});inputGraph.graph().width=layoutGraph.graph().width;inputGraph.graph().height=layoutGraph.graph().height}var graphNumAttrs=["nodesep","edgesep","ranksep","marginx","marginy"],graphDefaults={ranksep:50,edgesep:20,nodesep:50,rankdir:"tb"},graphAttrs=["acyclicer","ranker","rankdir","align"],nodeNumAttrs=["width","height"],nodeDefaults={width:0,height:0},edgeNumAttrs=["minlen","weight","width","height","labeloffset"],edgeDefaults={minlen:1,weight:1,width:0,height:0,labeloffset:10,labelpos:"r"},edgeAttrs=["labelpos"];function buildLayoutGraph(inputGraph){var g=new Graph({multigraph:true,compound:true}),graph=canonicalize(inputGraph.graph());g.setGraph(_.merge({},graphDefaults,selectNumberAttrs(graph,graphNumAttrs),_.pick(graph,graphAttrs)));_.each(inputGraph.nodes(),function(v){var node=canonicalize(inputGraph.node(v));g.setNode(v,_.defaults(selectNumberAttrs(node,nodeNumAttrs),nodeDefaults));g.setParent(v,inputGraph.parent(v))});_.each(inputGraph.edges(),function(e){var edge=canonicalize(inputGraph.edge(e));g.setEdge(e,_.merge({},edgeDefaults,selectNumberAttrs(edge,edgeNumAttrs),_.pick(edge,edgeAttrs)))});return g}function makeSpaceForEdgeLabels(g){var graph=g.graph();graph.ranksep/=2;_.each(g.edges(),function(e){var edge=g.edge(e);edge.minlen*=2;if(edge.labelpos.toLowerCase()!=="c"){if(graph.rankdir==="TB"||graph.rankdir==="BT"){edge.width+=edge.labeloffset}else{edge.height+=edge.labeloffset}}})}function injectEdgeLabelProxies(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.width&&edge.height){var v=g.node(e.v),w=g.node(e.w),label={rank:(w.rank-v.rank)/2+v.rank,e:e};util.addDummyNode(g,"edge-proxy",label,"_ep")}})}function assignRankMinMax(g){var maxRank=0;_.each(g.nodes(),function(v){var node=g.node(v);if(node.borderTop){node.minRank=g.node(node.borderTop).rank;node.maxRank=g.node(node.borderBottom).rank;maxRank=_.max(maxRank,node.maxRank)}});g.graph().maxRank=maxRank}function removeEdgeLabelProxies(g){_.each(g.nodes(),function(v){var node=g.node(v);if(node.dummy==="edge-proxy"){g.edge(node.e).labelRank=node.rank;g.removeNode(v)}})}function translateGraph(g){var minX=Number.POSITIVE_INFINITY,maxX=0,minY=Number.POSITIVE_INFINITY,maxY=0,graphLabel=g.graph(),marginX=graphLabel.marginx||0,marginY=graphLabel.marginy||0;function getExtremes(attrs){var x=attrs.x,y=attrs.y,w=attrs.width,h=attrs.height;minX=Math.min(minX,x-w/2);maxX=Math.max(maxX,x+w/2);minY=Math.min(minY,y-h/2);maxY=Math.max(maxY,y+h/2)}_.each(g.nodes(),function(v){getExtremes(g.node(v))});_.each(g.edges(),function(e){var edge=g.edge(e);if(_.has(edge,"x")){getExtremes(edge)}});minX-=marginX;minY-=marginY;_.each(g.nodes(),function(v){var node=g.node(v);node.x-=minX;node.y-=minY});_.each(g.edges(),function(e){var 
edge=g.edge(e);_.each(edge.points,function(p){p.x-=minX;p.y-=minY});if(_.has(edge,"x")){edge.x-=minX}if(_.has(edge,"y")){edge.y-=minY}});graphLabel.width=maxX-minX+marginX;graphLabel.height=maxY-minY+marginY}function assignNodeIntersects(g){_.each(g.edges(),function(e){var edge=g.edge(e),nodeV=g.node(e.v),nodeW=g.node(e.w),p1,p2;if(!edge.points){edge.points=[];p1=nodeW;p2=nodeV}else{p1=edge.points[0];p2=edge.points[edge.points.length-1]}edge.points.unshift(util.intersectRect(nodeV,p1));edge.points.push(util.intersectRect(nodeW,p2))})}function fixupEdgeLabelCoords(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(_.has(edge,"x")){if(edge.labelpos==="l"||edge.labelpos==="r"){edge.width-=edge.labeloffset}switch(edge.labelpos){case"l":edge.x-=edge.width/2+edge.labeloffset;break;case"r":edge.x+=edge.width/2+edge.labeloffset;break}}})}function reversePointsForReversedEdges(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.reversed){edge.points.reverse()}})}function removeBorderNodes(g){_.each(g.nodes(),function(v){if(g.children(v).length){var node=g.node(v),t=g.node(node.borderTop),b=g.node(node.borderBottom),l=g.node(_.last(node.borderLeft)),r=g.node(_.last(node.borderRight));node.width=Math.abs(r.x-l.x);node.height=Math.abs(b.y-t.y);node.x=l.x+node.width/2;node.y=t.y+node.height/2}});_.each(g.nodes(),function(v){if(g.node(v).dummy==="border"){g.removeNode(v)}})}function removeSelfEdges(g){_.each(g.edges(),function(e){if(e.v===e.w){var node=g.node(e.v);if(!node.selfEdges){node.selfEdges=[]}node.selfEdges.push({e:e,label:g.edge(e)});g.removeEdge(e)}})}function insertSelfEdges(g){var layers=util.buildLayerMatrix(g);_.each(layers,function(layer){var orderShift=0;_.each(layer,function(v,i){var node=g.node(v);node.order=i+orderShift;_.each(node.selfEdges,function(selfEdge){util.addDummyNode(g,"selfedge",{width:selfEdge.label.width,height:selfEdge.label.height,rank:node.rank,order:i+ ++orderShift,e:selfEdge.e,label:selfEdge.label},"_se")});delete node.selfEdges})})}function positionSelfEdges(g){_.each(g.nodes(),function(v){var node=g.node(v);if(node.dummy==="selfedge"){var selfNode=g.node(node.e.v),x=selfNode.x+selfNode.width/2,y=selfNode.y,dx=node.x-x,dy=selfNode.height/2;g.setEdge(node.e,node.label);g.removeNode(v);node.label.points=[{x:x+2*dx/3,y:y-dy},{x:x+5*dx/6,y:y-dy},{x:x+dx,y:y},{x:x+5*dx/6,y:y+dy},{x:x+2*dx/3,y:y+dy}];node.label.x=node.x;node.label.y=node.y}})}function selectNumberAttrs(obj,attrs){return _.mapValues(_.pick(obj,attrs),Number)}function canonicalize(attrs){var newAttrs={};_.each(attrs,function(v,k){newAttrs[k.toLowerCase()]=v});return newAttrs}},{"./acyclic":28,"./add-border-segments":29,"./coordinate-system":30,"./graphlib":33,"./lodash":36,"./nesting-graph":37,"./normalize":38,"./order":43,"./parent-dummy-chains":48,"./position":50,"./rank":52,"./util":55}],36:[function(require,module,exports){module.exports=require(20)},{"/Users/andrew/Documents/dev/dagre-d3/lib/lodash.js":20,lodash:77}],37:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports={run:run,cleanup:cleanup};function run(g){var root=util.addDummyNode(g,"root",{},"_root"),depths=treeDepths(g),height=_.max(depths)-1,nodeSep=2*height+1;g.graph().nestingRoot=root;_.each(g.edges(),function(e){g.edge(e).minlen*=nodeSep});var weight=sumWeights(g)+1;_.each(g.children(),function(child){dfs(g,root,nodeSep,weight,height,depths,child)});g.graph().nodeRankFactor=nodeSep}function dfs(g,root,nodeSep,weight,height,depths,v){var 
children=g.children(v);if(!children.length){if(v!==root){g.setEdge(root,v,{weight:0,minlen:nodeSep})}return}var top=util.addBorderNode(g,"_bt"),bottom=util.addBorderNode(g,"_bb"),label=g.node(v);g.setParent(top,v);label.borderTop=top;g.setParent(bottom,v);label.borderBottom=bottom;_.each(children,function(child){dfs(g,root,nodeSep,weight,height,depths,child);var childNode=g.node(child),childTop=childNode.borderTop?childNode.borderTop:child,childBottom=childNode.borderBottom?childNode.borderBottom:child,thisWeight=childNode.borderTop?weight:2*weight,minlen=childTop!==childBottom?1:height-depths[v]+1;g.setEdge(top,childTop,{weight:thisWeight,minlen:minlen,nestingEdge:true});g.setEdge(childBottom,bottom,{weight:thisWeight,minlen:minlen,nestingEdge:true})});if(!g.parent(v)){g.setEdge(root,top,{weight:0,minlen:height+depths[v]})}}function treeDepths(g){var depths={};function dfs(v,depth){var children=g.children(v);if(children&&children.length){_.each(children,function(child){dfs(child,depth+1)})}depths[v]=depth}_.each(g.children(),function(v){dfs(v,1)});return depths}function sumWeights(g){return _.reduce(g.edges(),function(acc,e){return acc+g.edge(e).weight},0)}function cleanup(g){var graphLabel=g.graph();g.removeNode(graphLabel.nestingRoot);delete graphLabel.nestingRoot;_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.nestingEdge){g.removeEdge(e)}})}},{"./lodash":36,"./util":55}],38:[function(require,module,exports){"use strict";var _=require("./lodash"),util=require("./util");module.exports={run:run,undo:undo};function run(g){g.graph().dummyChains=[];_.each(g.edges(),function(edge){normalizeEdge(g,edge)})}function normalizeEdge(g,e){var v=e.v,vRank=g.node(v).rank,w=e.w,wRank=g.node(w).rank,name=e.name,edgeLabel=g.edge(e),labelRank=edgeLabel.labelRank;if(wRank===vRank+1)return;g.removeEdge(e);var dummy,attrs,i;for(i=0,++vRank;vRank0){if(index%2){weightSum+=tree[index+1]}index=index-1>>1;tree[index]+=entry.weight}cc+=entry.weight*weightSum}));return cc}},{"../lodash":36}],43:[function(require,module,exports){"use strict";var _=require("../lodash"),initOrder=require("./init-order"),crossCount=require("./cross-count"),sortSubgraph=require("./sort-subgraph"),buildLayerGraph=require("./build-layer-graph"),addSubgraphConstraints=require("./add-subgraph-constraints"),Graph=require("../graphlib").Graph,util=require("../util");module.exports=order;function order(g){var maxRank=util.maxRank(g),downLayerGraphs=buildLayerGraphs(g,_.range(1,maxRank+1),"inEdges"),upLayerGraphs=buildLayerGraphs(g,_.range(maxRank-1,-1,-1),"outEdges");var layering=initOrder(g);assignOrder(g,layering);var bestCC=Number.POSITIVE_INFINITY,best;for(var i=0,lastBest=0;lastBest<4;++i,++lastBest){sweepLayerGraphs(i%2?downLayerGraphs:upLayerGraphs,i%4>=2);layering=util.buildLayerMatrix(g);var cc=crossCount(g,layering);if(cc=vEntry.barycenter){mergeEntries(vEntry,uEntry)}}}function handleOut(vEntry){return function(wEntry){wEntry["in"].push(vEntry);if(--wEntry.indegree===0){sourceSet.push(wEntry)}}}while(sourceSet.length){var entry=sourceSet.pop();entries.push(entry);_.each(entry["in"].reverse(),handleIn(entry));_.each(entry.out,handleOut(entry))}return _.chain(entries).filter(function(entry){return!entry.merged}).map(function(entry){return _.pick(entry,["vs","i","barycenter","weight"])}).value()}function mergeEntries(target,source){var 
sum=0,weight=0;if(target.weight){sum+=target.barycenter*target.weight;weight+=target.weight}if(source.weight){sum+=source.barycenter*source.weight;weight+=source.weight}target.vs=source.vs.concat(target.vs);target.barycenter=sum/weight;target.weight=weight;target.i=Math.min(source.i,target.i);source.merged=true}},{"../lodash":36}],46:[function(require,module,exports){var _=require("../lodash"),barycenter=require("./barycenter"),resolveConflicts=require("./resolve-conflicts"),sort=require("./sort");module.exports=sortSubgraph;function sortSubgraph(g,v,cg,biasRight){var movable=g.children(v),node=g.node(v),bl=node?node.borderLeft:undefined,br=node?node.borderRight:undefined,subgraphs={};if(bl){movable=_.filter(movable,function(w){return w!==bl&&w!==br})}var barycenters=barycenter(g,movable);_.each(barycenters,function(entry){if(g.children(entry.v).length){var subgraphResult=sortSubgraph(g,entry.v,cg,biasRight);subgraphs[entry.v]=subgraphResult;if(_.has(subgraphResult,"barycenter")){mergeBarycenters(entry,subgraphResult)}}});var entries=resolveConflicts(barycenters,cg);expandSubgraphs(entries,subgraphs);var result=sort(entries,biasRight);if(bl){result.vs=_.flatten([bl,result.vs,br],true);if(g.predecessors(bl).length){var blPred=g.node(g.predecessors(bl)[0]),brPred=g.node(g.predecessors(br)[0]);if(!_.has(result,"barycenter")){result.barycenter=0;result.weight=0}result.barycenter=(result.barycenter*result.weight+blPred.order+brPred.order)/(result.weight+2);result.weight+=2}}return result}function expandSubgraphs(entries,subgraphs){_.each(entries,function(entry){entry.vs=_.flatten(entry.vs.map(function(v){if(subgraphs[v]){return subgraphs[v].vs}return v}),true)})}function mergeBarycenters(target,other){if(!_.isUndefined(target.barycenter)){target.barycenter=(target.barycenter*target.weight+other.barycenter*other.weight)/(target.weight+other.weight);target.weight+=other.weight}else{target.barycenter=other.barycenter;target.weight=other.weight}}},{"../lodash":36,"./barycenter":40,"./resolve-conflicts":45,"./sort":47}],47:[function(require,module,exports){var _=require("../lodash"),util=require("../util");module.exports=sort;function sort(entries,biasRight){var parts=util.partition(entries,function(entry){return _.has(entry,"barycenter")});var sortable=parts.lhs,unsortable=_.sortBy(parts.rhs,function(entry){return-entry.i}),vs=[],sum=0,weight=0,vsIndex=0;sortable.sort(compareWithBias(!!biasRight));vsIndex=consumeUnsortable(vs,unsortable,vsIndex);_.each(sortable,function(entry){vsIndex+=entry.vs.length;vs.push(entry.vs);sum+=entry.barycenter*entry.weight;weight+=entry.weight;vsIndex=consumeUnsortable(vs,unsortable,vsIndex)});var result={vs:_.flatten(vs,true)};if(weight){result.barycenter=sum/weight;result.weight=weight}return result}function consumeUnsortable(vs,unsortable,index){var last;while(unsortable.length&&(last=_.last(unsortable)).i<=index){unsortable.pop();vs.push(last.vs);index++}return index}function compareWithBias(bias){return function(entryV,entryW){if(entryV.barycenterentryW.barycenter){return 1}return!bias?entryV.i-entryW.i:entryW.i-entryV.i}}},{"../lodash":36,"../util":55}],48:[function(require,module,exports){var _=require("./lodash");module.exports=parentDummyChains;function parentDummyChains(g){var postorderNums=postorder(g);_.each(g.graph().dummyChains,function(v){var 
node=g.node(v),edgeObj=node.edgeObj,pathData=findPath(g,postorderNums,edgeObj.v,edgeObj.w),path=pathData.path,lca=pathData.lca,pathIdx=0,pathV=path[pathIdx],ascending=true;while(v!==edgeObj.w){node=g.node(v);if(ascending){while((pathV=path[pathIdx])!==lca&&g.node(pathV).maxRanklow||lim>postorderNums[parent].lim));lca=parent;parent=w;while((parent=g.parent(parent))!==lca){wPath.push(parent)}return{path:vPath.concat(wPath.reverse()),lca:lca}}function postorder(g){var result={},lim=0;function dfs(v){var low=lim;_.each(g.children(v),dfs);result[v]={low:low,lim:lim++}}_.each(g.children(),dfs);return result}},{"./lodash":36}],49:[function(require,module,exports){"use strict";var _=require("../lodash"),Graph=require("../graphlib").Graph,util=require("../util");module.exports={positionX:positionX,findType1Conflicts:findType1Conflicts,findType2Conflicts:findType2Conflicts,addConflict:addConflict,hasConflict:hasConflict,verticalAlignment:verticalAlignment,horizontalCompaction:horizontalCompaction,alignCoordinates:alignCoordinates,findSmallestWidthAlignment:findSmallestWidthAlignment,balance:balance};function findType1Conflicts(g,layering){var conflicts={};function visitLayer(prevLayer,layer){var k0=0,scanPos=0,prevLayerLength=prevLayer.length,lastNode=_.last(layer);_.each(layer,function(v,i){var w=findOtherInnerSegmentNode(g,v),k1=w?g.node(w).order:prevLayerLength;if(w||v===lastNode){_.each(layer.slice(scanPos,i+1),function(scanNode){_.each(g.predecessors(scanNode),function(u){var uLabel=g.node(u),uPos=uLabel.order;if((uPosnextNorthBorder)){addConflict(conflicts,u,v)}})}})}function visitLayer(north,south){var prevNorthPos=-1,nextNorthPos,southPos=0;_.each(south,function(v,southLookahead){if(g.node(v).dummy==="border"){var predecessors=g.predecessors(v);if(predecessors.length){nextNorthPos=g.node(predecessors[0]).order;scan(south,southPos,southLookahead,prevNorthPos,nextNorthPos);southPos=southLookahead;prevNorthPos=nextNorthPos}}scan(south,southPos,south.length,nextNorthPos,north.length)});return south}_.reduce(layering,visitLayer);return conflicts}function findOtherInnerSegmentNode(g,v){if(g.node(v).dummy){return _.find(g.predecessors(v),function(u){return g.node(u).dummy})}}function addConflict(conflicts,v,w){if(v>w){var tmp=v;v=w;w=tmp}var conflictsV=conflicts[v];if(!conflictsV){conflicts[v]=conflictsV={}}conflictsV[w]=true}function hasConflict(conflicts,v,w){if(v>w){var tmp=v;v=w;w=tmp}return _.has(conflicts[v],w)}function verticalAlignment(g,layering,conflicts,neighborFn){var root={},align={},pos={};_.each(layering,function(layer){_.each(layer,function(v,order){root[v]=v;align[v]=v;pos[v]=order})});_.each(layering,function(layer){var prevIdx=-1;_.each(layer,function(v){var ws=neighborFn(v);if(ws.length){ws=_.sortBy(ws,function(w){return pos[w]});var mp=(ws.length-1)/2;for(var i=Math.floor(mp),il=Math.ceil(mp);i<=il;++i){var w=ws[i];if(align[v]===v&&prevIdxwLabel.lim){tailLabel=wLabel;flip=true}var candidates=_.filter(g.edges(),function(edge){return flip===isDescendant(t,t.node(edge.v),tailLabel)&&flip!==isDescendant(t,t.node(edge.w),tailLabel)});return _.min(candidates,function(edge){return slack(g,edge)})}function exchangeEdges(t,g,e,f){var v=e.v,w=e.w;t.removeEdge(v,w);t.setEdge(f.v,f.w,{});initLowLimValues(t);initCutValues(t,g);updateRanks(t,g)}function updateRanks(t,g){var root=_.find(t.nodes(),function(v){return!g.node(v).parent}),vs=preorder(t,root);vs=vs.slice(1);_.each(vs,function(v){var 
parent=t.node(v).parent,edge=g.edge(v,parent),flipped=false;if(!edge){edge=g.edge(parent,v);flipped=true}g.node(v).rank=g.node(parent).rank+(flipped?edge.minlen:-edge.minlen)})}function isTreeEdge(tree,u,v){return tree.hasEdge(u,v)}function isDescendant(tree,vLabel,rootLabel){return rootLabel.low<=vLabel.lim&&vLabel.lim<=rootLabel.lim}},{"../graphlib":33,"../lodash":36,"../util":55,"./feasible-tree":51,"./util":54}],54:[function(require,module,exports){"use strict";var _=require("../lodash");module.exports={longestPath:longestPath,slack:slack};function longestPath(g){var visited={};function dfs(v){var label=g.node(v);if(_.has(visited,v)){return label.rank}visited[v]=true;var rank=_.min(_.map(g.outEdges(v),function(e){return dfs(e.w)-g.edge(e).minlen}));if(rank===Number.POSITIVE_INFINITY){rank=0}return label.rank=rank}_.each(g.sources(),dfs)}function slack(g,e){return g.node(e.w).rank-g.node(e.v).rank-g.edge(e).minlen}},{"../lodash":36}],55:[function(require,module,exports){"use strict";var _=require("./lodash"),Graph=require("./graphlib").Graph;module.exports={addDummyNode:addDummyNode,simplify:simplify,asNonCompoundGraph:asNonCompoundGraph,successorWeights:successorWeights,predecessorWeights:predecessorWeights,intersectRect:intersectRect,buildLayerMatrix:buildLayerMatrix,normalizeRanks:normalizeRanks,removeEmptyRanks:removeEmptyRanks,addBorderNode:addBorderNode,maxRank:maxRank,partition:partition,time:time,notime:notime};function addDummyNode(g,type,attrs,name){var v;do{v=_.uniqueId(name)}while(g.hasNode(v));attrs.dummy=type;g.setNode(v,attrs);return v}function simplify(g){var simplified=(new Graph).setGraph(g.graph());_.each(g.nodes(),function(v){simplified.setNode(v,g.node(v))});_.each(g.edges(),function(e){var simpleLabel=simplified.edge(e.v,e.w)||{weight:0,minlen:1},label=g.edge(e);simplified.setEdge(e.v,e.w,{weight:simpleLabel.weight+label.weight,minlen:Math.max(simpleLabel.minlen,label.minlen)})});return simplified}function asNonCompoundGraph(g){var simplified=new Graph({multigraph:g.isMultigraph()}).setGraph(g.graph());_.each(g.nodes(),function(v){if(!g.children(v).length){simplified.setNode(v,g.node(v))}});_.each(g.edges(),function(e){simplified.setEdge(e,g.edge(e))});return simplified}function successorWeights(g){var weightMap=_.map(g.nodes(),function(v){var sucs={};_.each(g.outEdges(v),function(e){sucs[e.w]=(sucs[e.w]||0)+g.edge(e).weight});return sucs});return _.zipObject(g.nodes(),weightMap)}function predecessorWeights(g){var weightMap=_.map(g.nodes(),function(v){var preds={};_.each(g.inEdges(v),function(e){preds[e.v]=(preds[e.v]||0)+g.edge(e).weight});return preds});return _.zipObject(g.nodes(),weightMap)}function intersectRect(rect,point){var x=rect.x;var y=rect.y;var dx=point.x-x;var dy=point.y-y;var w=rect.width/2;var h=rect.height/2;if(!dx&&!dy){throw new Error("Not possible to find intersection inside of the rectangle")}var sx,sy;if(Math.abs(dy)*w>Math.abs(dx)*h){if(dy<0){h=-h}sx=h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=w*dy/dx}return{x:x+sx,y:y+sy}}function buildLayerMatrix(g){var layering=_.map(_.range(maxRank(g)+1),function(){return[]});_.each(g.nodes(),function(v){var node=g.node(v),rank=node.rank;if(!_.isUndefined(rank)){layering[rank][node.order]=v}});return layering}function normalizeRanks(g){var min=_.min(_.map(g.nodes(),function(v){return g.node(v).rank}));_.each(g.nodes(),function(v){var node=g.node(v);if(_.has(node,"rank")){node.rank-=min}})}function removeEmptyRanks(g){var offset=_.min(_.map(g.nodes(),function(v){return g.node(v).rank}));var 
layers=[];_.each(g.nodes(),function(v){var rank=g.node(v).rank-offset;if(!_.has(layers,rank)){layers[rank]=[]}layers[rank].push(v)});var delta=0,nodeRankFactor=g.graph().nodeRankFactor;_.each(layers,function(vs,i){if(_.isUndefined(vs)&&i%nodeRankFactor!==0){--delta}else if(delta){_.each(vs,function(v){g.node(v).rank+=delta})}})}function addBorderNode(g,prefix,rank,order){var node={width:0,height:0};if(arguments.length>=4){node.rank=rank;node.order=order}return addDummyNode(g,"border",node,prefix)}function maxRank(g){return _.max(_.map(g.nodes(),function(v){var rank=g.node(v).rank;if(!_.isUndefined(rank)){return rank}}))}function partition(collection,fn){var result={lhs:[],rhs:[]};_.each(collection,function(value){if(fn(value)){result.lhs.push(value)}else{result.rhs.push(value)}});return result}function time(name,fn){var start=_.now();try{return fn()}finally{console.log(name+" time: "+(_.now()-start)+"ms")}}function notime(name,fn){return fn()}},{"./graphlib":33,"./lodash":36}],56:[function(require,module,exports){module.exports="0.7.1"},{}],57:[function(require,module,exports){var lib=require("./lib");module.exports={Graph:lib.Graph,json:require("./lib/json"),alg:require("./lib/alg"),version:lib.version}},{"./lib":73,"./lib/alg":64,"./lib/json":74}],58:[function(require,module,exports){var _=require("../lodash");module.exports=components;function components(g){var visited={},cmpts=[],cmpt;function dfs(v){if(_.has(visited,v))return;visited[v]=true;cmpt.push(v);_.each(g.successors(v),dfs);_.each(g.predecessors(v),dfs)}_.each(g.nodes(),function(v){cmpt=[];dfs(v);if(cmpt.length){cmpts.push(cmpt)}});return cmpts}},{"../lodash":75}],59:[function(require,module,exports){var _=require("../lodash");module.exports=dfs;function dfs(g,vs,order){if(!_.isArray(vs)){vs=[vs]}var acc=[],visited={};_.each(vs,function(v){if(!g.hasNode(v)){throw new Error("Graph does not have node: "+v)}doDfs(g,v,order==="post",visited,acc)});return acc}function doDfs(g,v,postorder,visited,acc){if(!_.has(visited,v)){visited[v]=true;if(!postorder){acc.push(v)}_.each(g.neighbors(v),function(w){doDfs(g,w,postorder,visited,acc)});if(postorder){acc.push(v)}}}},{"../lodash":75}],60:[function(require,module,exports){var dijkstra=require("./dijkstra"),_=require("../lodash");module.exports=dijkstraAll;function dijkstraAll(g,weightFunc,edgeFunc){return _.transform(g.nodes(),function(acc,v){acc[v]=dijkstra(g,v,weightFunc,edgeFunc)},{})}},{"../lodash":75,"./dijkstra":61}],61:[function(require,module,exports){var _=require("../lodash"),PriorityQueue=require("../data/priority-queue");module.exports=dijkstra;var DEFAULT_WEIGHT_FUNC=_.constant(1);function dijkstra(g,source,weightFn,edgeFn){return runDijkstra(g,String(source),weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runDijkstra(g,source,weightFn,edgeFn){var results={},pq=new PriorityQueue,v,vEntry;var updateNeighbors=function(edge){var w=edge.v!==v?edge.v:edge.w,wEntry=results[w],weight=weightFn(edge),distance=vEntry.distance+weight;if(weight<0){throw new Error("dijkstra does not allow negative edge weights. 
"+"Bad edge: "+edge+" Weight: "+weight)}if(distance0){v=pq.removeMin();vEntry=results[v];if(vEntry.distance===Number.POSITIVE_INFINITY){break}edgeFn(v).forEach(updateNeighbors)}return results}},{"../data/priority-queue":71,"../lodash":75}],62:[function(require,module,exports){var _=require("../lodash"),tarjan=require("./tarjan");module.exports=findCycles;function findCycles(g){return _.filter(tarjan(g),function(cmpt){return cmpt.length>1})}},{"../lodash":75,"./tarjan":69}],63:[function(require,module,exports){var _=require("../lodash");module.exports=floydWarshall;var DEFAULT_WEIGHT_FUNC=_.constant(1);function floydWarshall(g,weightFn,edgeFn){return runFloydWarshall(g,weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runFloydWarshall(g,weightFn,edgeFn){var results={},nodes=g.nodes();nodes.forEach(function(v){results[v]={};results[v][v]={distance:0};nodes.forEach(function(w){if(v!==w){results[v][w]={distance:Number.POSITIVE_INFINITY}}});edgeFn(v).forEach(function(edge){var w=edge.v===v?edge.w:edge.v,d=weightFn(edge);results[v][w]={distance:d,predecessor:v}})});nodes.forEach(function(k){var rowK=results[k];nodes.forEach(function(i){var rowI=results[i];nodes.forEach(function(j){var ik=rowI[k];var kj=rowK[j];var ij=rowI[j];var altDistance=ik.distance+kj.distance;if(altDistance0){v=pq.removeMin();if(_.has(parents,v)){result.setEdge(v,parents[v])}else if(init){throw new Error("Input graph is not connected: "+g)}else{init=true}g.nodeEdges(v).forEach(updateNeighbors)}return result}},{"../data/priority-queue":71,"../graph":72,"../lodash":75}],69:[function(require,module,exports){var _=require("../lodash");module.exports=tarjan;function tarjan(g){var index=0,stack=[],visited={},results=[];function dfs(v){var entry=visited[v]={onStack:true,lowlink:index,index:index++};stack.push(v);g.successors(v).forEach(function(w){if(!_.has(visited,w)){dfs(w);entry.lowlink=Math.min(entry.lowlink,visited[w].lowlink)}else if(visited[w].onStack){entry.lowlink=Math.min(entry.lowlink,visited[w].index)}});if(entry.lowlink===entry.index){var cmpt=[],w;do{w=stack.pop();visited[w].onStack=false;cmpt.push(w)}while(v!==w);results.push(cmpt)}}g.nodes().forEach(function(v){if(!_.has(visited,v)){dfs(v)}});return results}},{"../lodash":75}],70:[function(require,module,exports){var _=require("../lodash");module.exports=topsort;topsort.CycleException=CycleException;function topsort(g){var visited={},stack={},results=[];function visit(node){if(_.has(stack,node)){throw new CycleException}if(!_.has(visited,node)){stack[node]=true;visited[node]=true;_.each(g.predecessors(node),visit);delete stack[node];results.push(node)}}_.each(g.sinks(),visit);if(_.size(visited)!==g.nodeCount()){throw new CycleException}return results}function CycleException(){}},{"../lodash":75}],71:[function(require,module,exports){var _=require("../lodash");module.exports=PriorityQueue;function PriorityQueue(){this._arr=[];this._keyIndices={}}PriorityQueue.prototype.size=function(){return this._arr.length};PriorityQueue.prototype.keys=function(){return this._arr.map(function(x){return x.key})};PriorityQueue.prototype.has=function(key){return _.has(this._keyIndices,key)};PriorityQueue.prototype.priority=function(key){var index=this._keyIndices[key];if(index!==undefined){return this._arr[index].priority}};PriorityQueue.prototype.min=function(){if(this.size()===0){throw new Error("Queue underflow")}return this._arr[0].key};PriorityQueue.prototype.add=function(key,priority){var 
keyIndices=this._keyIndices;key=String(key);if(!_.has(keyIndices,key)){var arr=this._arr;var index=arr.length;keyIndices[key]=index;arr.push({key:key,priority:priority});this._decrease(index);return true}return false};PriorityQueue.prototype.removeMin=function(){this._swap(0,this._arr.length-1);var min=this._arr.pop();delete this._keyIndices[min.key];this._heapify(0);return min.key};PriorityQueue.prototype.decrease=function(key,priority){var index=this._keyIndices[key];if(priority>this._arr[index].priority){throw new Error("New priority is greater than current priority. "+"Key: "+key+" Old: "+this._arr[index].priority+" New: "+priority)}this._arr[index].priority=priority;this._decrease(index)};PriorityQueue.prototype._heapify=function(i){var arr=this._arr;var l=2*i,r=l+1,largest=i;if(l>1;if(arr[parent].priority1){this.setNode(v,value)}else{this.setNode(v)}},this);return this};Graph.prototype.setNode=function(v,value){if(_.has(this._nodes,v)){if(arguments.length>1){this._nodes[v]=value}return this}this._nodes[v]=arguments.length>1?value:this._defaultNodeLabelFn(v);if(this._isCompound){this._parent[v]=GRAPH_NODE;this._children[v]={};this._children[GRAPH_NODE][v]=true}this._in[v]={};this._preds[v]={};this._out[v]={};this._sucs[v]={};++this._nodeCount;return this};Graph.prototype.node=function(v){return this._nodes[v]};Graph.prototype.hasNode=function(v){return _.has(this._nodes,v)};Graph.prototype.removeNode=function(v){var self=this;if(_.has(this._nodes,v)){var removeEdge=function(e){self.removeEdge(self._edgeObjs[e])};delete this._nodes[v];if(this._isCompound){this._removeFromParentsChildList(v);delete this._parent[v];_.each(this.children(v),function(child){this.setParent(child)},this);delete this._children[v]}_.each(_.keys(this._in[v]),removeEdge);delete this._in[v];delete this._preds[v];_.each(_.keys(this._out[v]),removeEdge);delete this._out[v];delete this._sucs[v];--this._nodeCount}return this};Graph.prototype.setParent=function(v,parent){if(!this._isCompound){throw new Error("Cannot set parent in a non-compound graph")}if(_.isUndefined(parent)){parent=GRAPH_NODE}else{for(var ancestor=parent;!_.isUndefined(ancestor);ancestor=this.parent(ancestor)){if(ancestor===v){throw new Error("Setting "+parent+" as parent of "+v+" would create create a cycle")}}this.setNode(parent)}this.setNode(v);this._removeFromParentsChildList(v);this._parent[v]=parent;this._children[parent][v]=true;return this};Graph.prototype._removeFromParentsChildList=function(v){delete this._children[this._parent[v]][v]};Graph.prototype.parent=function(v){if(this._isCompound){var parent=this._parent[v];if(parent!==GRAPH_NODE){return parent}}};Graph.prototype.children=function(v){if(_.isUndefined(v)){v=GRAPH_NODE}if(this._isCompound){var children=this._children[v];if(children){return _.keys(children)}}else if(v===GRAPH_NODE){return this.nodes()}else if(this.hasNode(v)){return[]}};Graph.prototype.predecessors=function(v){var predsV=this._preds[v];if(predsV){return _.keys(predsV)}};Graph.prototype.successors=function(v){var sucsV=this._sucs[v];if(sucsV){return _.keys(sucsV)}};Graph.prototype.neighbors=function(v){var preds=this.predecessors(v);if(preds){return _.union(preds,this.successors(v))}};Graph.prototype.setDefaultEdgeLabel=function(newDefault){if(!_.isFunction(newDefault)){newDefault=_.constant(newDefault)}this._defaultEdgeLabelFn=newDefault;return this};Graph.prototype.edgeCount=function(){return this._edgeCount};Graph.prototype.edges=function(){return 
_.values(this._edgeObjs)};Graph.prototype.setPath=function(vs,value){var self=this,args=arguments;_.reduce(vs,function(v,w){if(args.length>1){self.setEdge(v,w,value)}else{self.setEdge(v,w)}return w});return this};Graph.prototype.setEdge=function(){var v,w,name,value,valueSpecified=false;if(_.isPlainObject(arguments[0])){v=arguments[0].v;w=arguments[0].w;name=arguments[0].name;if(arguments.length===2){value=arguments[1];valueSpecified=true}}else{v=arguments[0];w=arguments[1];name=arguments[3];if(arguments.length>2){value=arguments[2];valueSpecified=true}}v=""+v;w=""+w;if(!_.isUndefined(name)){name=""+name}var e=edgeArgsToId(this._isDirected,v,w,name);if(_.has(this._edgeLabels,e)){if(valueSpecified){this._edgeLabels[e]=value}return this}if(!_.isUndefined(name)&&!this._isMultigraph){throw new Error("Cannot set a named edge when isMultigraph = false")}this.setNode(v);this.setNode(w);this._edgeLabels[e]=valueSpecified?value:this._defaultEdgeLabelFn(v,w,name);var edgeObj=edgeArgsToObj(this._isDirected,v,w,name);v=edgeObj.v;w=edgeObj.w;Object.freeze(edgeObj);this._edgeObjs[e]=edgeObj;incrementOrInitEntry(this._preds[w],v);incrementOrInitEntry(this._sucs[v],w);this._in[w][e]=edgeObj;this._out[v][e]=edgeObj;this._edgeCount++;return this};Graph.prototype.edge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return this._edgeLabels[e]};Graph.prototype.hasEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return _.has(this._edgeLabels,e)};Graph.prototype.removeEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name),edge=this._edgeObjs[e];if(edge){v=edge.v;w=edge.w;delete this._edgeLabels[e];delete this._edgeObjs[e];decrementOrRemoveEntry(this._preds[w],v);decrementOrRemoveEntry(this._sucs[v],w);delete this._in[w][e];delete this._out[v][e];this._edgeCount--}return this};Graph.prototype.inEdges=function(v,u){var inV=this._in[v];if(inV){var edges=_.values(inV);if(!u){return edges}return _.filter(edges,function(edge){return edge.v===u})}};Graph.prototype.outEdges=function(v,w){var outV=this._out[v];if(outV){var edges=_.values(outV);if(!w){return edges}return _.filter(edges,function(edge){return edge.w===w})}};Graph.prototype.nodeEdges=function(v,w){var inEdges=this.inEdges(v,w);if(inEdges){return inEdges.concat(this.outEdges(v,w))}};function incrementOrInitEntry(map,k){if(_.has(map,k)){map[k]++}else{map[k]=1}}function decrementOrRemoveEntry(map,k){if(!--map[k]){delete map[k]}}function edgeArgsToId(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}return v+EDGE_KEY_DELIM+w+EDGE_KEY_DELIM+(_.isUndefined(name)?DEFAULT_EDGE_NAME:name)}function edgeArgsToObj(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}var edgeObj={v:v,w:w};if(name){edgeObj.name=name}return edgeObj}function edgeObjToId(isDirected,edgeObj){return edgeArgsToId(isDirected,edgeObj.v,edgeObj.w,edgeObj.name)}},{"./lodash":75}],73:[function(require,module,exports){module.exports={Graph:require("./graph"),version:require("./version")}},{"./graph":72,"./version":76}],74:[function(require,module,exports){var _=require("./lodash"),Graph=require("./graph");module.exports={write:write,read:read};function write(g){var 
json={options:{directed:g.isDirected(),multigraph:g.isMultigraph(),compound:g.isCompound()},nodes:writeNodes(g),edges:writeEdges(g)};if(!_.isUndefined(g.graph())){json.value=_.clone(g.graph())}return json}function writeNodes(g){return _.map(g.nodes(),function(v){var nodeValue=g.node(v),parent=g.parent(v),node={v:v};if(!_.isUndefined(nodeValue)){node.value=nodeValue}if(!_.isUndefined(parent)){node.parent=parent}return node})}function writeEdges(g){return _.map(g.edges(),function(e){var edgeValue=g.edge(e),edge={v:e.v,w:e.w};if(!_.isUndefined(e.name)){edge.name=e.name}if(!_.isUndefined(edgeValue)){edge.value=edgeValue}return edge})}function read(json){var g=new Graph(json.options).setGraph(json.value);_.each(json.nodes,function(entry){g.setNode(entry.v,entry.value);if(entry.parent){g.setParent(entry.v,entry.parent)}});_.each(json.edges,function(entry){g.setEdge({v:entry.v,w:entry.w,name:entry.name},entry.value)});return g}},{"./graph":72,"./lodash":75}],75:[function(require,module,exports){module.exports=require(20)},{"/Users/andrew/Documents/dev/dagre-d3/lib/lodash.js":20,lodash:77}],76:[function(require,module,exports){module.exports="1.0.1"},{}],77:[function(require,module,exports){(function(global){(function(){var undefined;var arrayPool=[],objectPool=[];var idCounter=0;var keyPrefix=+new Date+"";var largeArraySize=75;var maxPoolSize=40;var whitespace=" \f \ufeff"+"\n\r\u2028\u2029"+" ᠎              ";var reEmptyStringLeading=/\b__p \+= '';/g,reEmptyStringMiddle=/\b(__p \+=) '' \+/g,reEmptyStringTrailing=/(__e\(.*?\)|\b__t\)) \+\n'';/g;var reEsTemplate=/\$\{([^\\}]*(?:\\.[^\\}]*)*)\}/g;var reFlags=/\w*$/;var reFuncName=/^\s*function[ \n\r\t]+\w/;var reInterpolate=/<%=([\s\S]+?)%>/g;var reLeadingSpacesAndZeros=RegExp("^["+whitespace+"]*0+(?=.$)");var reNoMatch=/($^)/;var reThis=/\bthis\b/;var reUnescapedString=/['\n\r\t\u2028\u2029\\]/g;var contextProps=["Array","Boolean","Date","Function","Math","Number","Object","RegExp","String","_","attachEvent","clearTimeout","isFinite","isNaN","parseInt","setTimeout"];var templateCounter=0;var argsClass="[object Arguments]",arrayClass="[object Array]",boolClass="[object Boolean]",dateClass="[object Date]",funcClass="[object Function]",numberClass="[object Number]",objectClass="[object Object]",regexpClass="[object RegExp]",stringClass="[object String]";var cloneableClasses={};cloneableClasses[funcClass]=false;cloneableClasses[argsClass]=cloneableClasses[arrayClass]=cloneableClasses[boolClass]=cloneableClasses[dateClass]=cloneableClasses[numberClass]=cloneableClasses[objectClass]=cloneableClasses[regexpClass]=cloneableClasses[stringClass]=true;var debounceOptions={leading:false,maxWait:0,trailing:false};var descriptor={configurable:false,enumerable:false,value:null,writable:false};var objectTypes={"boolean":false,"function":true,object:true,number:false,string:false,undefined:false};var stringEscapes={"\\":"\\","'":"'","\n":"n","\r":"r"," ":"t","\u2028":"u2028","\u2029":"u2029"};var root=objectTypes[typeof window]&&window||this;var freeExports=objectTypes[typeof exports]&&exports&&!exports.nodeType&&exports;var freeModule=objectTypes[typeof module]&&module&&!module.nodeType&&module;var moduleExports=freeModule&&freeModule.exports===freeExports&&freeExports;var freeGlobal=objectTypes[typeof global]&&global;if(freeGlobal&&(freeGlobal.global===freeGlobal||freeGlobal.window===freeGlobal)){root=freeGlobal}function baseIndexOf(array,value,fromIndex){var index=(fromIndex||0)-1,length=array?array.length:0;while(++index-1?0:-1:cache?0:-1}function 
cachePush(value){var cache=this.cache,type=typeof value;if(type=="boolean"||value==null){cache[value]=true}else{if(type!="number"&&type!="string"){type="object"}var key=type=="number"?value:keyPrefix+value,typeCache=cache[type]||(cache[type]={});if(type=="object"){(typeCache[key]||(typeCache[key]=[])).push(value)}else{typeCache[key]=true}}}function charAtCallback(value){return value.charCodeAt(0)}function compareAscending(a,b){var ac=a.criteria,bc=b.criteria,index=-1,length=ac.length;while(++indexother||typeof value=="undefined"){return 1}if(value/g,evaluate:/<%([\s\S]+?)%>/g,interpolate:reInterpolate,variable:"",imports:{_:lodash}};function baseBind(bindData){var func=bindData[0],partialArgs=bindData[2],thisArg=bindData[4];function bound(){if(partialArgs){var args=slice(partialArgs);push.apply(args,arguments)}if(this instanceof bound){var thisBinding=baseCreate(func.prototype),result=func.apply(thisBinding,args||arguments);return isObject(result)?result:thisBinding}return func.apply(thisArg,args||arguments)}setBindData(bound,bindData);return bound}function baseClone(value,isDeep,callback,stackA,stackB){if(callback){var result=callback(value);if(typeof result!="undefined"){return result}}var isObj=isObject(value);if(isObj){var className=toString.call(value);if(!cloneableClasses[className]){return value}var ctor=ctorByClass[className];switch(className){case boolClass:case dateClass:return new ctor(+value);case numberClass:case stringClass:return new ctor(value);case regexpClass:result=ctor(value.source,reFlags.exec(value));result.lastIndex=value.lastIndex;return result}}else{return value}var isArr=isArray(value);if(isDeep){var initedStack=!stackA;stackA||(stackA=getArray());stackB||(stackB=getArray());var length=stackA.length;while(length--){if(stackA[length]==value){return stackB[length]}}result=isArr?ctor(value.length):{}}else{result=isArr?slice(value):assign({},value)}if(isArr){if(hasOwnProperty.call(value,"index")){result.index=value.index}if(hasOwnProperty.call(value,"input")){result.input=value.input}}if(!isDeep){return result}stackA.push(value);stackB.push(result);(isArr?forEach:forOwn)(value,function(objValue,key){result[key]=baseClone(objValue,isDeep,callback,stackA,stackB)});if(initedStack){releaseArray(stackA);releaseArray(stackB)}return result}function baseCreate(prototype,properties){return isObject(prototype)?nativeCreate(prototype):{}; +})}function enterEdge(t,g,edge){var v=edge.v,w=edge.w;if(!g.hasEdge(v,w)){v=edge.w;w=edge.v}var vLabel=t.node(v),wLabel=t.node(w),tailLabel=vLabel,flip=false;if(vLabel.lim>wLabel.lim){tailLabel=wLabel;flip=true}var candidates=_.filter(g.edges(),function(edge){return flip===isDescendant(t,t.node(edge.v),tailLabel)&&flip!==isDescendant(t,t.node(edge.w),tailLabel)});return _.min(candidates,function(edge){return slack(g,edge)})}function exchangeEdges(t,g,e,f){var v=e.v,w=e.w;t.removeEdge(v,w);t.setEdge(f.v,f.w,{});initLowLimValues(t);initCutValues(t,g);updateRanks(t,g)}function updateRanks(t,g){var root=_.find(t.nodes(),function(v){return!g.node(v).parent}),vs=preorder(t,root);vs=vs.slice(1);_.each(vs,function(v){var parent=t.node(v).parent,edge=g.edge(v,parent),flipped=false;if(!edge){edge=g.edge(parent,v);flipped=true}g.node(v).rank=g.node(parent).rank+(flipped?edge.minlen:-edge.minlen)})}function isTreeEdge(tree,u,v){return tree.hasEdge(u,v)}function isDescendant(tree,vLabel,rootLabel){return 
rootLabel.low<=vLabel.lim&&vLabel.lim<=rootLabel.lim}},{"../graphlib":33,"../lodash":36,"../util":55,"./feasible-tree":51,"./util":54}],54:[function(require,module,exports){"use strict";var _=require("../lodash");module.exports={longestPath:longestPath,slack:slack};function longestPath(g){var visited={};function dfs(v){var label=g.node(v);if(_.has(visited,v)){return label.rank}visited[v]=true;var rank=_.min(_.map(g.outEdges(v),function(e){return dfs(e.w)-g.edge(e).minlen}));if(rank===Number.POSITIVE_INFINITY){rank=0}return label.rank=rank}_.each(g.sources(),dfs)}function slack(g,e){return g.node(e.w).rank-g.node(e.v).rank-g.edge(e).minlen}},{"../lodash":36}],55:[function(require,module,exports){"use strict";var _=require("./lodash"),Graph=require("./graphlib").Graph;module.exports={addDummyNode:addDummyNode,simplify:simplify,asNonCompoundGraph:asNonCompoundGraph,successorWeights:successorWeights,predecessorWeights:predecessorWeights,intersectRect:intersectRect,buildLayerMatrix:buildLayerMatrix,normalizeRanks:normalizeRanks,removeEmptyRanks:removeEmptyRanks,addBorderNode:addBorderNode,maxRank:maxRank,partition:partition,time:time,notime:notime};function addDummyNode(g,type,attrs,name){var v;do{v=_.uniqueId(name)}while(g.hasNode(v));attrs.dummy=type;g.setNode(v,attrs);return v}function simplify(g){var simplified=(new Graph).setGraph(g.graph());_.each(g.nodes(),function(v){simplified.setNode(v,g.node(v))});_.each(g.edges(),function(e){var simpleLabel=simplified.edge(e.v,e.w)||{weight:0,minlen:1},label=g.edge(e);simplified.setEdge(e.v,e.w,{weight:simpleLabel.weight+label.weight,minlen:Math.max(simpleLabel.minlen,label.minlen)})});return simplified}function asNonCompoundGraph(g){var simplified=new Graph({multigraph:g.isMultigraph()}).setGraph(g.graph());_.each(g.nodes(),function(v){if(!g.children(v).length){simplified.setNode(v,g.node(v))}});_.each(g.edges(),function(e){simplified.setEdge(e,g.edge(e))});return simplified}function successorWeights(g){var weightMap=_.map(g.nodes(),function(v){var sucs={};_.each(g.outEdges(v),function(e){sucs[e.w]=(sucs[e.w]||0)+g.edge(e).weight});return sucs});return _.zipObject(g.nodes(),weightMap)}function predecessorWeights(g){var weightMap=_.map(g.nodes(),function(v){var preds={};_.each(g.inEdges(v),function(e){preds[e.v]=(preds[e.v]||0)+g.edge(e).weight});return preds});return _.zipObject(g.nodes(),weightMap)}function intersectRect(rect,point){var x=rect.x;var y=rect.y;var dx=point.x-x;var dy=point.y-y;var w=rect.width/2;var h=rect.height/2;if(!dx&&!dy){throw new Error("Not possible to find intersection inside of the rectangle")}var sx,sy;if(Math.abs(dy)*w>Math.abs(dx)*h){if(dy<0){h=-h}sx=h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=w*dy/dx}return{x:x+sx,y:y+sy}}function buildLayerMatrix(g){var layering=_.map(_.range(maxRank(g)+1),function(){return[]});_.each(g.nodes(),function(v){var node=g.node(v),rank=node.rank;if(!_.isUndefined(rank)){layering[rank][node.order]=v}});return layering}function normalizeRanks(g){var min=_.min(_.map(g.nodes(),function(v){return g.node(v).rank}));_.each(g.nodes(),function(v){var node=g.node(v);if(_.has(node,"rank")){node.rank-=min}})}function removeEmptyRanks(g){var offset=_.min(_.map(g.nodes(),function(v){return g.node(v).rank}));var layers=[];_.each(g.nodes(),function(v){var rank=g.node(v).rank-offset;if(!_.has(layers,rank)){layers[rank]=[]}layers[rank].push(v)});var delta=0,nodeRankFactor=g.graph().nodeRankFactor;_.each(layers,function(vs,i){if(_.isUndefined(vs)&&i%nodeRankFactor!==0){--delta}else 
if(delta){_.each(vs,function(v){g.node(v).rank+=delta})}})}function addBorderNode(g,prefix,rank,order){var node={width:0,height:0};if(arguments.length>=4){node.rank=rank;node.order=order}return addDummyNode(g,"border",node,prefix)}function maxRank(g){return _.max(_.map(g.nodes(),function(v){var rank=g.node(v).rank;if(!_.isUndefined(rank)){return rank}}))}function partition(collection,fn){var result={lhs:[],rhs:[]};_.each(collection,function(value){if(fn(value)){result.lhs.push(value)}else{result.rhs.push(value)}});return result}function time(name,fn){var start=_.now();try{return fn()}finally{console.log(name+" time: "+(_.now()-start)+"ms")}}function notime(name,fn){return fn()}},{"./graphlib":33,"./lodash":36}],56:[function(require,module,exports){module.exports="0.7.1"},{}],57:[function(require,module,exports){var lib=require("./lib");module.exports={Graph:lib.Graph,json:require("./lib/json"),alg:require("./lib/alg"),version:lib.version}},{"./lib":73,"./lib/alg":64,"./lib/json":74}],58:[function(require,module,exports){var _=require("../lodash");module.exports=components;function components(g){var visited={},cmpts=[],cmpt;function dfs(v){if(_.has(visited,v))return;visited[v]=true;cmpt.push(v);_.each(g.successors(v),dfs);_.each(g.predecessors(v),dfs)}_.each(g.nodes(),function(v){cmpt=[];dfs(v);if(cmpt.length){cmpts.push(cmpt)}});return cmpts}},{"../lodash":75}],59:[function(require,module,exports){var _=require("../lodash");module.exports=dfs;function dfs(g,vs,order){if(!_.isArray(vs)){vs=[vs]}var acc=[],visited={};_.each(vs,function(v){if(!g.hasNode(v)){throw new Error("Graph does not have node: "+v)}doDfs(g,v,order==="post",visited,acc)});return acc}function doDfs(g,v,postorder,visited,acc){if(!_.has(visited,v)){visited[v]=true;if(!postorder){acc.push(v)}_.each(g.neighbors(v),function(w){doDfs(g,w,postorder,visited,acc)});if(postorder){acc.push(v)}}}},{"../lodash":75}],60:[function(require,module,exports){var dijkstra=require("./dijkstra"),_=require("../lodash");module.exports=dijkstraAll;function dijkstraAll(g,weightFunc,edgeFunc){return _.transform(g.nodes(),function(acc,v){acc[v]=dijkstra(g,v,weightFunc,edgeFunc)},{})}},{"../lodash":75,"./dijkstra":61}],61:[function(require,module,exports){var _=require("../lodash"),PriorityQueue=require("../data/priority-queue");module.exports=dijkstra;var DEFAULT_WEIGHT_FUNC=_.constant(1);function dijkstra(g,source,weightFn,edgeFn){return runDijkstra(g,String(source),weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runDijkstra(g,source,weightFn,edgeFn){var results={},pq=new PriorityQueue,v,vEntry;var updateNeighbors=function(edge){var w=edge.v!==v?edge.v:edge.w,wEntry=results[w],weight=weightFn(edge),distance=vEntry.distance+weight;if(weight<0){throw new Error("dijkstra does not allow negative edge weights. 
"+"Bad edge: "+edge+" Weight: "+weight)}if(distance0){v=pq.removeMin();vEntry=results[v];if(vEntry.distance===Number.POSITIVE_INFINITY){break}edgeFn(v).forEach(updateNeighbors)}return results}},{"../data/priority-queue":71,"../lodash":75}],62:[function(require,module,exports){var _=require("../lodash"),tarjan=require("./tarjan");module.exports=findCycles;function findCycles(g){return _.filter(tarjan(g),function(cmpt){return cmpt.length>1})}},{"../lodash":75,"./tarjan":69}],63:[function(require,module,exports){var _=require("../lodash");module.exports=floydWarshall;var DEFAULT_WEIGHT_FUNC=_.constant(1);function floydWarshall(g,weightFn,edgeFn){return runFloydWarshall(g,weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runFloydWarshall(g,weightFn,edgeFn){var results={},nodes=g.nodes();nodes.forEach(function(v){results[v]={};results[v][v]={distance:0};nodes.forEach(function(w){if(v!==w){results[v][w]={distance:Number.POSITIVE_INFINITY}}});edgeFn(v).forEach(function(edge){var w=edge.v===v?edge.w:edge.v,d=weightFn(edge);results[v][w]={distance:d,predecessor:v}})});nodes.forEach(function(k){var rowK=results[k];nodes.forEach(function(i){var rowI=results[i];nodes.forEach(function(j){var ik=rowI[k];var kj=rowK[j];var ij=rowI[j];var altDistance=ik.distance+kj.distance;if(altDistance0){v=pq.removeMin();if(_.has(parents,v)){result.setEdge(v,parents[v])}else if(init){throw new Error("Input graph is not connected: "+g)}else{init=true}g.nodeEdges(v).forEach(updateNeighbors)}return result}},{"../data/priority-queue":71,"../graph":72,"../lodash":75}],69:[function(require,module,exports){var _=require("../lodash");module.exports=tarjan;function tarjan(g){var index=0,stack=[],visited={},results=[];function dfs(v){var entry=visited[v]={onStack:true,lowlink:index,index:index++};stack.push(v);g.successors(v).forEach(function(w){if(!_.has(visited,w)){dfs(w);entry.lowlink=Math.min(entry.lowlink,visited[w].lowlink)}else if(visited[w].onStack){entry.lowlink=Math.min(entry.lowlink,visited[w].index)}});if(entry.lowlink===entry.index){var cmpt=[],w;do{w=stack.pop();visited[w].onStack=false;cmpt.push(w)}while(v!==w);results.push(cmpt)}}g.nodes().forEach(function(v){if(!_.has(visited,v)){dfs(v)}});return results}},{"../lodash":75}],70:[function(require,module,exports){var _=require("../lodash");module.exports=topsort;topsort.CycleException=CycleException;function topsort(g){var visited={},stack={},results=[];function visit(node){if(_.has(stack,node)){throw new CycleException}if(!_.has(visited,node)){stack[node]=true;visited[node]=true;_.each(g.predecessors(node),visit);delete stack[node];results.push(node)}}_.each(g.sinks(),visit);if(_.size(visited)!==g.nodeCount()){throw new CycleException}return results}function CycleException(){}},{"../lodash":75}],71:[function(require,module,exports){var _=require("../lodash");module.exports=PriorityQueue;function PriorityQueue(){this._arr=[];this._keyIndices={}}PriorityQueue.prototype.size=function(){return this._arr.length};PriorityQueue.prototype.keys=function(){return this._arr.map(function(x){return x.key})};PriorityQueue.prototype.has=function(key){return _.has(this._keyIndices,key)};PriorityQueue.prototype.priority=function(key){var index=this._keyIndices[key];if(index!==undefined){return this._arr[index].priority}};PriorityQueue.prototype.min=function(){if(this.size()===0){throw new Error("Queue underflow")}return this._arr[0].key};PriorityQueue.prototype.add=function(key,priority){var 
keyIndices=this._keyIndices;key=String(key);if(!_.has(keyIndices,key)){var arr=this._arr;var index=arr.length;keyIndices[key]=index;arr.push({key:key,priority:priority});this._decrease(index);return true}return false};PriorityQueue.prototype.removeMin=function(){this._swap(0,this._arr.length-1);var min=this._arr.pop();delete this._keyIndices[min.key];this._heapify(0);return min.key};PriorityQueue.prototype.decrease=function(key,priority){var index=this._keyIndices[key];if(priority>this._arr[index].priority){throw new Error("New priority is greater than current priority. "+"Key: "+key+" Old: "+this._arr[index].priority+" New: "+priority)}this._arr[index].priority=priority;this._decrease(index)};PriorityQueue.prototype._heapify=function(i){var arr=this._arr;var l=2*i,r=l+1,largest=i;if(l>1;if(arr[parent].priority1){this.setNode(v,value)}else{this.setNode(v)}},this);return this};Graph.prototype.setNode=function(v,value){if(_.has(this._nodes,v)){if(arguments.length>1){this._nodes[v]=value}return this}this._nodes[v]=arguments.length>1?value:this._defaultNodeLabelFn(v);if(this._isCompound){this._parent[v]=GRAPH_NODE;this._children[v]={};this._children[GRAPH_NODE][v]=true}this._in[v]={};this._preds[v]={};this._out[v]={};this._sucs[v]={};++this._nodeCount;return this};Graph.prototype.node=function(v){return this._nodes[v]};Graph.prototype.hasNode=function(v){return _.has(this._nodes,v)};Graph.prototype.removeNode=function(v){var self=this;if(_.has(this._nodes,v)){var removeEdge=function(e){self.removeEdge(self._edgeObjs[e])};delete this._nodes[v];if(this._isCompound){this._removeFromParentsChildList(v);delete this._parent[v];_.each(this.children(v),function(child){this.setParent(child)},this);delete this._children[v]}_.each(_.keys(this._in[v]),removeEdge);delete this._in[v];delete this._preds[v];_.each(_.keys(this._out[v]),removeEdge);delete this._out[v];delete this._sucs[v];--this._nodeCount}return this};Graph.prototype.setParent=function(v,parent){if(!this._isCompound){throw new Error("Cannot set parent in a non-compound graph")}if(_.isUndefined(parent)){parent=GRAPH_NODE}else{for(var ancestor=parent;!_.isUndefined(ancestor);ancestor=this.parent(ancestor)){if(ancestor===v){throw new Error("Setting "+parent+" as parent of "+v+" would create create a cycle")}}this.setNode(parent)}this.setNode(v);this._removeFromParentsChildList(v);this._parent[v]=parent;this._children[parent][v]=true;return this};Graph.prototype._removeFromParentsChildList=function(v){delete this._children[this._parent[v]][v]};Graph.prototype.parent=function(v){if(this._isCompound){var parent=this._parent[v];if(parent!==GRAPH_NODE){return parent}}};Graph.prototype.children=function(v){if(_.isUndefined(v)){v=GRAPH_NODE}if(this._isCompound){var children=this._children[v];if(children){return _.keys(children)}}else if(v===GRAPH_NODE){return this.nodes()}else if(this.hasNode(v)){return[]}};Graph.prototype.predecessors=function(v){var predsV=this._preds[v];if(predsV){return _.keys(predsV)}};Graph.prototype.successors=function(v){var sucsV=this._sucs[v];if(sucsV){return _.keys(sucsV)}};Graph.prototype.neighbors=function(v){var preds=this.predecessors(v);if(preds){return _.union(preds,this.successors(v))}};Graph.prototype.setDefaultEdgeLabel=function(newDefault){if(!_.isFunction(newDefault)){newDefault=_.constant(newDefault)}this._defaultEdgeLabelFn=newDefault;return this};Graph.prototype.edgeCount=function(){return this._edgeCount};Graph.prototype.edges=function(){return 
_.values(this._edgeObjs)};Graph.prototype.setPath=function(vs,value){var self=this,args=arguments;_.reduce(vs,function(v,w){if(args.length>1){self.setEdge(v,w,value)}else{self.setEdge(v,w)}return w});return this};Graph.prototype.setEdge=function(){var v,w,name,value,valueSpecified=false;if(_.isPlainObject(arguments[0])){v=arguments[0].v;w=arguments[0].w;name=arguments[0].name;if(arguments.length===2){value=arguments[1];valueSpecified=true}}else{v=arguments[0];w=arguments[1];name=arguments[3];if(arguments.length>2){value=arguments[2];valueSpecified=true}}v=""+v;w=""+w;if(!_.isUndefined(name)){name=""+name}var e=edgeArgsToId(this._isDirected,v,w,name);if(_.has(this._edgeLabels,e)){if(valueSpecified){this._edgeLabels[e]=value}return this}if(!_.isUndefined(name)&&!this._isMultigraph){throw new Error("Cannot set a named edge when isMultigraph = false")}this.setNode(v);this.setNode(w);this._edgeLabels[e]=valueSpecified?value:this._defaultEdgeLabelFn(v,w,name);var edgeObj=edgeArgsToObj(this._isDirected,v,w,name);v=edgeObj.v;w=edgeObj.w;Object.freeze(edgeObj);this._edgeObjs[e]=edgeObj;incrementOrInitEntry(this._preds[w],v);incrementOrInitEntry(this._sucs[v],w);this._in[w][e]=edgeObj;this._out[v][e]=edgeObj;this._edgeCount++;return this};Graph.prototype.edge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return this._edgeLabels[e]};Graph.prototype.hasEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return _.has(this._edgeLabels,e)};Graph.prototype.removeEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name),edge=this._edgeObjs[e];if(edge){v=edge.v;w=edge.w;delete this._edgeLabels[e];delete this._edgeObjs[e];decrementOrRemoveEntry(this._preds[w],v);decrementOrRemoveEntry(this._sucs[v],w);delete this._in[w][e];delete this._out[v][e];this._edgeCount--}return this};Graph.prototype.inEdges=function(v,u){var inV=this._in[v];if(inV){var edges=_.values(inV);if(!u){return edges}return _.filter(edges,function(edge){return edge.v===u})}};Graph.prototype.outEdges=function(v,w){var outV=this._out[v];if(outV){var edges=_.values(outV);if(!w){return edges}return _.filter(edges,function(edge){return edge.w===w})}};Graph.prototype.nodeEdges=function(v,w){var inEdges=this.inEdges(v,w);if(inEdges){return inEdges.concat(this.outEdges(v,w))}};function incrementOrInitEntry(map,k){if(_.has(map,k)){map[k]++}else{map[k]=1}}function decrementOrRemoveEntry(map,k){if(!--map[k]){delete map[k]}}function edgeArgsToId(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}return v+EDGE_KEY_DELIM+w+EDGE_KEY_DELIM+(_.isUndefined(name)?DEFAULT_EDGE_NAME:name)}function edgeArgsToObj(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}var edgeObj={v:v,w:w};if(name){edgeObj.name=name}return edgeObj}function edgeObjToId(isDirected,edgeObj){return edgeArgsToId(isDirected,edgeObj.v,edgeObj.w,edgeObj.name)}},{"./lodash":75}],73:[function(require,module,exports){module.exports={Graph:require("./graph"),version:require("./version")}},{"./graph":72,"./version":76}],74:[function(require,module,exports){var _=require("./lodash"),Graph=require("./graph");module.exports={write:write,read:read};function write(g){var 
json={options:{directed:g.isDirected(),multigraph:g.isMultigraph(),compound:g.isCompound()},nodes:writeNodes(g),edges:writeEdges(g)};if(!_.isUndefined(g.graph())){json.value=_.clone(g.graph())}return json}function writeNodes(g){return _.map(g.nodes(),function(v){var nodeValue=g.node(v),parent=g.parent(v),node={v:v};if(!_.isUndefined(nodeValue)){node.value=nodeValue}if(!_.isUndefined(parent)){node.parent=parent}return node})}function writeEdges(g){return _.map(g.edges(),function(e){var edgeValue=g.edge(e),edge={v:e.v,w:e.w};if(!_.isUndefined(e.name)){edge.name=e.name}if(!_.isUndefined(edgeValue)){edge.value=edgeValue}return edge})}function read(json){var g=new Graph(json.options).setGraph(json.value);_.each(json.nodes,function(entry){g.setNode(entry.v,entry.value);if(entry.parent){g.setParent(entry.v,entry.parent)}});_.each(json.edges,function(entry){g.setEdge({v:entry.v,w:entry.w,name:entry.name},entry.value)});return g}},{"./graph":72,"./lodash":75}],75:[function(require,module,exports){module.exports=require(20)},{"/Users/andrew/Documents/dev/dagre-d3/lib/lodash.js":20,lodash:77}],76:[function(require,module,exports){module.exports="1.0.1"},{}],77:[function(require,module,exports){(function(global){(function(){var undefined;var arrayPool=[],objectPool=[];var idCounter=0;var keyPrefix=+new Date+"";var largeArraySize=75;var maxPoolSize=40;var whitespace=" \f \ufeff"+"\n\r\u2028\u2029"+" ᠎              ";var reEmptyStringLeading=/\b__p \+= '';/g,reEmptyStringMiddle=/\b(__p \+=) '' \+/g,reEmptyStringTrailing=/(__e\(.*?\)|\b__t\)) \+\n'';/g;var reEsTemplate=/\$\{([^\\}]*(?:\\.[^\\}]*)*)\}/g;var reFlags=/\w*$/;var reFuncName=/^\s*function[ \n\r\t]+\w/;var reInterpolate=/<%=([\s\S]+?)%>/g;var reLeadingSpacesAndZeros=RegExp("^["+whitespace+"]*0+(?=.$)");var reNoMatch=/($^)/;var reThis=/\bthis\b/;var reUnescapedString=/['\n\r\t\u2028\u2029\\]/g;var contextProps=["Array","Boolean","Date","Function","Math","Number","Object","RegExp","String","_","attachEvent","clearTimeout","isFinite","isNaN","parseInt","setTimeout"];var templateCounter=0;var argsClass="[object Arguments]",arrayClass="[object Array]",boolClass="[object Boolean]",dateClass="[object Date]",funcClass="[object Function]",numberClass="[object Number]",objectClass="[object Object]",regexpClass="[object RegExp]",stringClass="[object String]";var cloneableClasses={};cloneableClasses[funcClass]=false;cloneableClasses[argsClass]=cloneableClasses[arrayClass]=cloneableClasses[boolClass]=cloneableClasses[dateClass]=cloneableClasses[numberClass]=cloneableClasses[objectClass]=cloneableClasses[regexpClass]=cloneableClasses[stringClass]=true;var debounceOptions={leading:false,maxWait:0,trailing:false};var descriptor={configurable:false,enumerable:false,value:null,writable:false};var objectTypes={"boolean":false,"function":true,object:true,number:false,string:false,undefined:false};var stringEscapes={"\\":"\\","'":"'","\n":"n","\r":"r"," ":"t","\u2028":"u2028","\u2029":"u2029"};var root=objectTypes[typeof window]&&window||this;var freeExports=objectTypes[typeof exports]&&exports&&!exports.nodeType&&exports;var freeModule=objectTypes[typeof module]&&module&&!module.nodeType&&module;var moduleExports=freeModule&&freeModule.exports===freeExports&&freeExports;var freeGlobal=objectTypes[typeof global]&&global;if(freeGlobal&&(freeGlobal.global===freeGlobal||freeGlobal.window===freeGlobal)){root=freeGlobal}function baseIndexOf(array,value,fromIndex){var index=(fromIndex||0)-1,length=array?array.length:0;while(++index-1?0:-1:cache?0:-1}function 
cachePush(value){var cache=this.cache,type=typeof value;if(type=="boolean"||value==null){cache[value]=true}else{if(type!="number"&&type!="string"){type="object"}var key=type=="number"?value:keyPrefix+value,typeCache=cache[type]||(cache[type]={});if(type=="object"){(typeCache[key]||(typeCache[key]=[])).push(value)}else{typeCache[key]=true}}}function charAtCallback(value){return value.charCodeAt(0)}function compareAscending(a,b){var ac=a.criteria,bc=b.criteria,index=-1,length=ac.length;while(++indexother||typeof value=="undefined"){return 1}if(value/g,evaluate:/<%([\s\S]+?)%>/g,interpolate:reInterpolate,variable:"",imports:{_:lodash}};function baseBind(bindData){var func=bindData[0],partialArgs=bindData[2],thisArg=bindData[4];function bound(){if(partialArgs){var args=slice(partialArgs);push.apply(args,arguments)}if(this instanceof bound){var thisBinding=baseCreate(func.prototype),result=func.apply(thisBinding,args||arguments);return isObject(result)?result:thisBinding}return func.apply(thisArg,args||arguments)}setBindData(bound,bindData);return bound}function baseClone(value,isDeep,callback,stackA,stackB){if(callback){var result=callback(value);if(typeof result!="undefined"){return result}}var isObj=isObject(value);if(isObj){var className=toString.call(value);if(!cloneableClasses[className]){return value}var ctor=ctorByClass[className];switch(className){case boolClass:case dateClass:return new ctor(+value);case numberClass:case stringClass:return new ctor(value);case regexpClass:result=ctor(value.source,reFlags.exec(value));result.lastIndex=value.lastIndex;return result}}else{return value}var isArr=isArray(value);if(isDeep){var initedStack=!stackA;stackA||(stackA=getArray());stackB||(stackB=getArray());var length=stackA.length;while(length--){if(stackA[length]==value){return stackB[length]}}result=isArr?ctor(value.length):{}}else{result=isArr?slice(value):assign({},value)}if(isArr){if(hasOwnProperty.call(value,"index")){result.index=value.index}if(hasOwnProperty.call(value,"input")){result.input=value.input}}if(!isDeep){return result}stackA.push(value);stackB.push(result);(isArr?forEach:forOwn)(value,function(objValue,key){result[key]=baseClone(objValue,isDeep,callback,stackA,stackB)});if(initedStack){releaseArray(stackA);releaseArray(stackB)}return result}function baseCreate(prototype,properties){ +return isObject(prototype)?nativeCreate(prototype):{}}if(!nativeCreate){baseCreate=function(){function Object(){}return function(prototype){if(isObject(prototype)){Object.prototype=prototype;var result=new Object;Object.prototype=null}return result||context.Object()}}()}function baseCreateCallback(func,thisArg,argCount){if(typeof func!="function"){return identity}if(typeof thisArg=="undefined"||!("prototype"in func)){return func}var bindData=func.__bindData__;if(typeof bindData=="undefined"){if(support.funcNames){bindData=!func.name}bindData=bindData||!support.funcDecomp;if(!bindData){var source=fnToString.call(func);if(!support.funcNames){bindData=!reFuncName.test(source)}if(!bindData){bindData=reThis.test(source);setBindData(func,bindData)}}}if(bindData===false||bindData!==true&&bindData[1]&1){return func}switch(argCount){case 1:return function(value){return func.call(thisArg,value)};case 2:return function(a,b){return func.call(thisArg,a,b)};case 3:return function(value,index,collection){return func.call(thisArg,value,index,collection)};case 4:return function(accumulator,value,index,collection){return func.call(thisArg,accumulator,value,index,collection)}}return bind(func,thisArg)}function 
baseCreateWrapper(bindData){var func=bindData[0],bitmask=bindData[1],partialArgs=bindData[2],partialRightArgs=bindData[3],thisArg=bindData[4],arity=bindData[5];var isBind=bitmask&1,isBindKey=bitmask&2,isCurry=bitmask&4,isCurryBound=bitmask&8,key=func;function bound(){var thisBinding=isBind?thisArg:this;if(partialArgs){var args=slice(partialArgs);push.apply(args,arguments)}if(partialRightArgs||isCurry){args||(args=slice(arguments));if(partialRightArgs){push.apply(args,partialRightArgs)}if(isCurry&&args.length=largeArraySize&&indexOf===baseIndexOf,result=[];if(isLarge){var cache=createCache(values);if(cache){indexOf=cacheIndexOf;values=cache}else{isLarge=false}}while(++index-1}})}}stackA.pop();stackB.pop();if(initedStack){releaseArray(stackA);releaseArray(stackB)}return result}function baseMerge(object,source,callback,stackA,stackB){(isArray(source)?forEach:forOwn)(source,function(source,key){var found,isArr,result=source,value=object[key];if(source&&((isArr=isArray(source))||isPlainObject(source))){var stackLength=stackA.length;while(stackLength--){if(found=stackA[stackLength]==source){value=stackB[stackLength];break}}if(!found){var isShallow;if(callback){result=callback(value,source);if(isShallow=typeof result!="undefined"){value=result}}if(!isShallow){value=isArr?isArray(value)?value:[]:isPlainObject(value)?value:{}}stackA.push(source);stackB.push(value);if(!isShallow){baseMerge(value,source,callback,stackA,stackB)}}}else{if(callback){result=callback(value,source);if(typeof result=="undefined"){result=source}}if(typeof result!="undefined"){value=result}}object[key]=value})}function baseRandom(min,max){return min+floor(nativeRandom()*(max-min+1))}function baseUniq(array,isSorted,callback){var index=-1,indexOf=getIndexOf(),length=array?array.length:0,result=[];var isLarge=!isSorted&&length>=largeArraySize&&indexOf===baseIndexOf,seen=callback||isLarge?getArray():result;if(isLarge){var cache=createCache(seen);indexOf=cacheIndexOf;seen=cache}while(++index":">",'"':""","'":"'"};var htmlUnescapes=invert(htmlEscapes);var reEscapedHtml=RegExp("("+keys(htmlUnescapes).join("|")+")","g"),reUnescapedHtml=RegExp("["+keys(htmlEscapes).join("")+"]","g");var assign=function(object,source,guard){var index,iterable=object,result=iterable;if(!iterable)return result;var args=arguments,argsIndex=0,argsLength=typeof guard=="number"?2:args.length;if(argsLength>3&&typeof args[argsLength-2]=="function"){var callback=baseCreateCallback(args[--argsLength-1],args[argsLength--],2)}else if(argsLength>2&&typeof args[argsLength-1]=="function"){callback=args[--argsLength]}while(++argsIndex3&&typeof args[length-2]=="function"){var callback=baseCreateCallback(args[--length-1],args[length--],2)}else if(length>2&&typeof args[length-1]=="function"){callback=args[--length]}var sources=slice(arguments,1,length),index=-1,stackA=getArray(),stackB=getArray();while(++index-1}else if(typeof length=="number"){result=(isString(collection)?collection.indexOf(target,fromIndex):indexOf(collection,target,fromIndex))>-1}else{forOwn(collection,function(value){if(++index>=fromIndex){return!(result=value===target)}})}return result}var countBy=createAggregator(function(result,value,key){hasOwnProperty.call(result,key)?result[key]++:result[key]=1});function every(collection,callback,thisArg){var result=true;callback=lodash.createCallback(callback,thisArg,3);var index=-1,length=collection?collection.length:0;if(typeof 
length=="number"){while(++indexresult){result=value}}}else{callback=callback==null&&isString(collection)?charAtCallback:lodash.createCallback(callback,thisArg,3);forEach(collection,function(value,index,collection){var current=callback(value,index,collection);if(current>computed){computed=current;result=value}})}return result}function min(collection,callback,thisArg){var computed=Infinity,result=computed;if(typeof callback!="function"&&thisArg&&thisArg[callback]===collection){callback=null}if(callback==null&&isArray(collection)){var index=-1,length=collection.length;while(++index=largeArraySize&&createCache(argsIndex?args[argsIndex]:seen))}}var array=args[0],index=-1,length=array?array.length:0,result=[];outer:while(++index>>1;callback(array[mid])=largeArraySize&&indexOf===baseIndexOf,result=[];if(isLarge){var cache=createCache(values);if(cache){indexOf=cacheIndexOf;values=cache}else{isLarge=false}}while(++index-1}})}}stackA.pop();stackB.pop();if(initedStack){releaseArray(stackA);releaseArray(stackB)}return result}function baseMerge(object,source,callback,stackA,stackB){(isArray(source)?forEach:forOwn)(source,function(source,key){var found,isArr,result=source,value=object[key];if(source&&((isArr=isArray(source))||isPlainObject(source))){var stackLength=stackA.length;while(stackLength--){if(found=stackA[stackLength]==source){value=stackB[stackLength];break}}if(!found){var isShallow;if(callback){result=callback(value,source);if(isShallow=typeof result!="undefined"){value=result}}if(!isShallow){value=isArr?isArray(value)?value:[]:isPlainObject(value)?value:{}}stackA.push(source);stackB.push(value);if(!isShallow){baseMerge(value,source,callback,stackA,stackB)}}}else{if(callback){result=callback(value,source);if(typeof result=="undefined"){result=source}}if(typeof result!="undefined"){value=result}}object[key]=value})}function baseRandom(min,max){return min+floor(nativeRandom()*(max-min+1))}function baseUniq(array,isSorted,callback){var index=-1,indexOf=getIndexOf(),length=array?array.length:0,result=[];var isLarge=!isSorted&&length>=largeArraySize&&indexOf===baseIndexOf,seen=callback||isLarge?getArray():result;if(isLarge){var cache=createCache(seen);indexOf=cacheIndexOf;seen=cache}while(++index":">",'"':""","'":"'"};var htmlUnescapes=invert(htmlEscapes);var reEscapedHtml=RegExp("("+keys(htmlUnescapes).join("|")+")","g"),reUnescapedHtml=RegExp("["+keys(htmlEscapes).join("")+"]","g");var assign=function(object,source,guard){var index,iterable=object,result=iterable;if(!iterable)return result;var args=arguments,argsIndex=0,argsLength=typeof guard=="number"?2:args.length;if(argsLength>3&&typeof args[argsLength-2]=="function"){var callback=baseCreateCallback(args[--argsLength-1],args[argsLength--],2)}else if(argsLength>2&&typeof args[argsLength-1]=="function"){callback=args[--argsLength]}while(++argsIndex3&&typeof args[length-2]=="function"){var callback=baseCreateCallback(args[--length-1],args[length--],2)}else if(length>2&&typeof args[length-1]=="function"){callback=args[--length]}var sources=slice(arguments,1,length),index=-1,stackA=getArray(),stackB=getArray();while(++index-1}else if(typeof length=="number"){result=(isString(collection)?collection.indexOf(target,fromIndex):indexOf(collection,target,fromIndex))>-1}else{forOwn(collection,function(value){if(++index>=fromIndex){return!(result=value===target)}})}return result}var countBy=createAggregator(function(result,value,key){hasOwnProperty.call(result,key)?result[key]++:result[key]=1});function every(collection,callback,thisArg){var 
result=true;callback=lodash.createCallback(callback,thisArg,3);var index=-1,length=collection?collection.length:0;if(typeof length=="number"){while(++indexresult){result=value}}}else{callback=callback==null&&isString(collection)?charAtCallback:lodash.createCallback(callback,thisArg,3);forEach(collection,function(value,index,collection){var current=callback(value,index,collection);if(current>computed){computed=current;result=value}})}return result}function min(collection,callback,thisArg){var computed=Infinity,result=computed;if(typeof callback!="function"&&thisArg&&thisArg[callback]===collection){callback=null}if(callback==null&&isArray(collection)){var index=-1,length=collection.length;while(++index=largeArraySize&&createCache(argsIndex?args[argsIndex]:seen))}}var array=args[0],index=-1,length=array?array.length:0,result=[];outer:while(++index>>1;callback(array[mid])1?arguments:arguments[0],index=-1,length=array?max(pluck(array,"length")):0,result=Array(length<0?0:length);while(++index2?createWrapper(func,17,slice(arguments,2),null,thisArg):createWrapper(func,1,null,null,thisArg)}function bindAll(object){var funcs=arguments.length>1?baseFlatten(arguments,true,false,1):functions(object),index=-1,length=funcs.length;while(++index2?createWrapper(key,19,slice(arguments,2),null,object):createWrapper(key,3,null,null,object)}function compose(){var funcs=arguments,length=funcs.length;while(length--){if(!isFunction(funcs[length])){throw new TypeError}}return function(){var args=arguments,length=funcs.length;while(length--){args=[funcs[length].apply(this,args)]}return args[0]}}function curry(func,arity){arity=typeof arity=="number"?arity:+arity||func.length;return createWrapper(func,4,null,null,null,arity)}function debounce(func,wait,options){var args,maxTimeoutId,result,stamp,thisArg,timeoutId,trailingCall,lastCalled=0,maxWait=false,trailing=true;if(!isFunction(func)){throw new TypeError}wait=nativeMax(0,wait)||0;if(options===true){var leading=true;trailing=false}else if(isObject(options)){leading=options.leading;maxWait="maxWait"in options&&(nativeMax(wait,options.maxWait)||0);trailing="trailing"in options?options.trailing:trailing}var delayed=function(){var remaining=wait-(now()-stamp);if(remaining<=0){if(maxTimeoutId){clearTimeout(maxTimeoutId)}var isCalled=trailingCall;maxTimeoutId=timeoutId=trailingCall=undefined;if(isCalled){lastCalled=now();result=func.apply(thisArg,args);if(!timeoutId&&!maxTimeoutId){args=thisArg=null}}}else{timeoutId=setTimeout(delayed,remaining)}};var maxDelayed=function(){if(timeoutId){clearTimeout(timeoutId)}maxTimeoutId=timeoutId=trailingCall=undefined;if(trailing||maxWait!==wait){lastCalled=now();result=func.apply(thisArg,args);if(!timeoutId&&!maxTimeoutId){args=thisArg=null}}};return function(){args=arguments;stamp=now();thisArg=this;trailingCall=trailing&&(timeoutId||!leading);if(maxWait===false){var leadingCall=leading&&!timeoutId}else{if(!maxTimeoutId&&!leading){lastCalled=stamp}var remaining=maxWait-(stamp-lastCalled),isCalled=remaining<=0;if(isCalled){if(maxTimeoutId){maxTimeoutId=clearTimeout(maxTimeoutId)}lastCalled=stamp;result=func.apply(thisArg,args)}else if(!maxTimeoutId){maxTimeoutId=setTimeout(maxDelayed,remaining)}}if(isCalled&&timeoutId){timeoutId=clearTimeout(timeoutId)}else if(!timeoutId&&wait!==maxWait){timeoutId=setTimeout(delayed,wait)}if(leadingCall){isCalled=true;result=func.apply(thisArg,args)}if(isCalled&&!timeoutId&&!maxTimeoutId){args=thisArg=null}return result}}function defer(func){if(!isFunction(func)){throw new TypeError}var 
args=slice(arguments,1);return setTimeout(function(){func.apply(undefined,args)},1)}function delay(func,wait){if(!isFunction(func)){throw new TypeError}var args=slice(arguments,2);return setTimeout(function(){func.apply(undefined,args)},wait)}function memoize(func,resolver){if(!isFunction(func)){throw new TypeError}var memoized=function(){var cache=memoized.cache,key=resolver?resolver.apply(this,arguments):keyPrefix+arguments[0];return hasOwnProperty.call(cache,key)?cache[key]:cache[key]=func.apply(this,arguments)};memoized.cache={};return memoized}function once(func){var ran,result;if(!isFunction(func)){throw new TypeError}return function(){if(ran){return result}ran=true;result=func.apply(this,arguments);func=null;return result}}function partial(func){return createWrapper(func,16,slice(arguments,1))}function partialRight(func){return createWrapper(func,32,null,slice(arguments,1))}function throttle(func,wait,options){var leading=true,trailing=true;if(!isFunction(func)){throw new TypeError}if(options===false){leading=false}else if(isObject(options)){leading="leading"in options?options.leading:leading;trailing="trailing"in options?options.trailing:trailing}debounceOptions.leading=leading;debounceOptions.maxWait=wait;debounceOptions.trailing=trailing;return debounce(func,wait,debounceOptions)}function wrap(value,wrapper){return createWrapper(wrapper,16,[value])}function constant(value){return function(){return value}}function createCallback(func,thisArg,argCount){var type=typeof func;if(func==null||type=="function"){return baseCreateCallback(func,thisArg,argCount)}if(type!="object"){return property(func)}var props=keys(func),key=props[0],a=func[key];if(props.length==1&&a===a&&!isObject(a)){return function(object){var b=object[key];return a===b&&(a!==0||1/a==1/b)}}return function(object){var length=props.length,result=false;while(length--){if(!(result=baseIsEqual(object[props[length]],func[props[length]],null,true))){break}}return result}}function escape(string){return string==null?"":String(string).replace(reUnescapedHtml,escapeHtmlChar)}function identity(value){return value}function mixin(object,source,options){var chain=true,methodNames=source&&functions(source);if(!source||!options&&!methodNames.length){if(options==null){options=source}ctor=lodashWrapper;source=object;object=lodash;methodNames=functions(source)}if(options===false){chain=false}else if(isObject(options)&&"chain"in options){chain=options.chain}var ctor=object,isFunc=isFunction(ctor);forEach(methodNames,function(methodName){var func=object[methodName]=source[methodName];if(isFunc){ctor.prototype[methodName]=function(){var chainAll=this.__chain__,value=this.__wrapped__,args=[value];push.apply(args,arguments);var result=func.apply(object,args);if(chain||chainAll){if(value===result&&isObject(result)){return this}result=new ctor(result);result.__chain__=chainAll}return result}}})}function noConflict(){context._=oldDash;return this}function noop(){}var now=isNative(now=Date.now)&&now||function(){return(new Date).getTime()};var parseInt=nativeParseInt(whitespace+"08")==8?nativeParseInt:function(value,radix){return nativeParseInt(isString(value)?value.replace(reLeadingSpacesAndZeros,""):value,radix||0)};function property(key){return function(object){return object[key]}}function random(min,max,floating){var noMin=min==null,noMax=max==null;if(floating==null){if(typeof min=="boolean"&&noMax){floating=min;min=1}else if(!noMax&&typeof 
max=="boolean"){floating=max;noMax=true}}if(noMin&&noMax){max=1}min=+min||0;if(noMax){max=min;min=0}else{max=+max||0}if(floating||min%1||max%1){var rand=nativeRandom();return nativeMin(min+rand*(max-min+parseFloat("1e-"+((rand+"").length-1))),max)}return baseRandom(min,max)}function result(object,key){if(object){var value=object[key];return isFunction(value)?object[key]():value}}function template(text,data,options){var settings=lodash.templateSettings;text=String(text||"");options=defaults({},options,settings);var imports=defaults({},options.imports,settings.imports),importsKeys=keys(imports),importsValues=values(imports);var isEvaluating,index=0,interpolate=options.interpolate||reNoMatch,source="__p += '";var reDelimiters=RegExp((options.escape||reNoMatch).source+"|"+interpolate.source+"|"+(interpolate===reInterpolate?reEsTemplate:reNoMatch).source+"|"+(options.evaluate||reNoMatch).source+"|$","g");text.replace(reDelimiters,function(match,escapeValue,interpolateValue,esTemplateValue,evaluateValue,offset){interpolateValue||(interpolateValue=esTemplateValue);source+=text.slice(index,offset).replace(reUnescapedString,escapeStringChar);if(escapeValue){source+="' +\n__e("+escapeValue+") +\n'"}if(evaluateValue){isEvaluating=true;source+="';\n"+evaluateValue+";\n__p += '"}if(interpolateValue){source+="' +\n((__t = ("+interpolateValue+")) == null ? '' : __t) +\n'"}index=offset+match.length;return match});source+="';\n";var variable=options.variable,hasVariable=variable;if(!hasVariable){variable="obj";source="with ("+variable+") {\n"+source+"\n}\n"}source=(isEvaluating?source.replace(reEmptyStringLeading,""):source).replace(reEmptyStringMiddle,"$1").replace(reEmptyStringTrailing,"$1;");source="function("+variable+") {\n"+(hasVariable?"":variable+" || ("+variable+" = {});\n")+"var __t, __p = '', __e = _.escape"+(isEvaluating?", __j = Array.prototype.join;\n"+"function print() { __p += __j.call(arguments, '') }\n":";\n")+source+"return __p\n}";var sourceURL="\n/*\n//# sourceURL="+(options.sourceURL||"/lodash/template/source["+templateCounter++ +"]")+"\n*/";try{var result=Function(importsKeys,"return "+source+sourceURL).apply(undefined,importsValues)}catch(e){e.source=source;throw e}if(data){return result(data)}result.source=source;return result}function times(n,callback,thisArg){n=(n=+n)>-1?n:0;var index=-1,result=Array(n);callback=baseCreateCallback(callback,thisArg,1);while(++index1?arguments:arguments[0],index=-1,length=array?max(pluck(array,"length")):0,result=Array(length<0?0:length);while(++index2?createWrapper(func,17,slice(arguments,2),null,thisArg):createWrapper(func,1,null,null,thisArg)}function bindAll(object){var funcs=arguments.length>1?baseFlatten(arguments,true,false,1):functions(object),index=-1,length=funcs.length;while(++index2?createWrapper(key,19,slice(arguments,2),null,object):createWrapper(key,3,null,null,object)}function compose(){var funcs=arguments,length=funcs.length;while(length--){if(!isFunction(funcs[length])){throw new TypeError}}return function(){var args=arguments,length=funcs.length;while(length--){args=[funcs[length].apply(this,args)]}return args[0]}}function curry(func,arity){arity=typeof arity=="number"?arity:+arity||func.length;return createWrapper(func,4,null,null,null,arity)}function debounce(func,wait,options){var args,maxTimeoutId,result,stamp,thisArg,timeoutId,trailingCall,lastCalled=0,maxWait=false,trailing=true;if(!isFunction(func)){throw new TypeError}wait=nativeMax(0,wait)||0;if(options===true){var leading=true;trailing=false}else 
if(isObject(options)){leading=options.leading;maxWait="maxWait"in options&&(nativeMax(wait,options.maxWait)||0);trailing="trailing"in options?options.trailing:trailing}var delayed=function(){var remaining=wait-(now()-stamp);if(remaining<=0){if(maxTimeoutId){clearTimeout(maxTimeoutId)}var isCalled=trailingCall;maxTimeoutId=timeoutId=trailingCall=undefined;if(isCalled){lastCalled=now();result=func.apply(thisArg,args);if(!timeoutId&&!maxTimeoutId){args=thisArg=null}}}else{timeoutId=setTimeout(delayed,remaining)}};var maxDelayed=function(){if(timeoutId){clearTimeout(timeoutId)}maxTimeoutId=timeoutId=trailingCall=undefined;if(trailing||maxWait!==wait){lastCalled=now();result=func.apply(thisArg,args);if(!timeoutId&&!maxTimeoutId){args=thisArg=null}}};return function(){args=arguments;stamp=now();thisArg=this;trailingCall=trailing&&(timeoutId||!leading);if(maxWait===false){var leadingCall=leading&&!timeoutId}else{if(!maxTimeoutId&&!leading){lastCalled=stamp}var remaining=maxWait-(stamp-lastCalled),isCalled=remaining<=0;if(isCalled){if(maxTimeoutId){maxTimeoutId=clearTimeout(maxTimeoutId)}lastCalled=stamp;result=func.apply(thisArg,args)}else if(!maxTimeoutId){maxTimeoutId=setTimeout(maxDelayed,remaining)}}if(isCalled&&timeoutId){timeoutId=clearTimeout(timeoutId)}else if(!timeoutId&&wait!==maxWait){timeoutId=setTimeout(delayed,wait)}if(leadingCall){isCalled=true;result=func.apply(thisArg,args)}if(isCalled&&!timeoutId&&!maxTimeoutId){args=thisArg=null}return result}}function defer(func){if(!isFunction(func)){throw new TypeError}var args=slice(arguments,1);return setTimeout(function(){func.apply(undefined,args)},1)}function delay(func,wait){if(!isFunction(func)){throw new TypeError}var args=slice(arguments,2);return setTimeout(function(){func.apply(undefined,args)},wait)}function memoize(func,resolver){if(!isFunction(func)){throw new TypeError}var memoized=function(){var cache=memoized.cache,key=resolver?resolver.apply(this,arguments):keyPrefix+arguments[0];return hasOwnProperty.call(cache,key)?cache[key]:cache[key]=func.apply(this,arguments)};memoized.cache={};return memoized}function once(func){var ran,result;if(!isFunction(func)){throw new TypeError}return function(){if(ran){return result}ran=true;result=func.apply(this,arguments);func=null;return result}}function partial(func){return createWrapper(func,16,slice(arguments,1))}function partialRight(func){return createWrapper(func,32,null,slice(arguments,1))}function throttle(func,wait,options){var leading=true,trailing=true;if(!isFunction(func)){throw new TypeError}if(options===false){leading=false}else if(isObject(options)){leading="leading"in options?options.leading:leading;trailing="trailing"in options?options.trailing:trailing}debounceOptions.leading=leading;debounceOptions.maxWait=wait;debounceOptions.trailing=trailing;return debounce(func,wait,debounceOptions)}function wrap(value,wrapper){return createWrapper(wrapper,16,[value])}function constant(value){return function(){return value}}function createCallback(func,thisArg,argCount){var type=typeof func;if(func==null||type=="function"){return baseCreateCallback(func,thisArg,argCount)}if(type!="object"){return property(func)}var props=keys(func),key=props[0],a=func[key];if(props.length==1&&a===a&&!isObject(a)){return function(object){var b=object[key];return a===b&&(a!==0||1/a==1/b)}}return function(object){var length=props.length,result=false;while(length--){if(!(result=baseIsEqual(object[props[length]],func[props[length]],null,true))){break}}return result}}function escape(string){return 
string==null?"":String(string).replace(reUnescapedHtml,escapeHtmlChar)}function identity(value){return value}function mixin(object,source,options){var chain=true,methodNames=source&&functions(source);if(!source||!options&&!methodNames.length){if(options==null){options=source}ctor=lodashWrapper;source=object;object=lodash;methodNames=functions(source)}if(options===false){chain=false}else if(isObject(options)&&"chain"in options){chain=options.chain}var ctor=object,isFunc=isFunction(ctor);forEach(methodNames,function(methodName){var func=object[methodName]=source[methodName];if(isFunc){ctor.prototype[methodName]=function(){var chainAll=this.__chain__,value=this.__wrapped__,args=[value];push.apply(args,arguments);var result=func.apply(object,args);if(chain||chainAll){if(value===result&&isObject(result)){return this}result=new ctor(result);result.__chain__=chainAll}return result}}})}function noConflict(){context._=oldDash;return this}function noop(){}var now=isNative(now=Date.now)&&now||function(){return(new Date).getTime()};var parseInt=nativeParseInt(whitespace+"08")==8?nativeParseInt:function(value,radix){return nativeParseInt(isString(value)?value.replace(reLeadingSpacesAndZeros,""):value,radix||0)};function property(key){return function(object){return object[key]}}function random(min,max,floating){var noMin=min==null,noMax=max==null;if(floating==null){if(typeof min=="boolean"&&noMax){floating=min;min=1}else if(!noMax&&typeof max=="boolean"){floating=max;noMax=true}}if(noMin&&noMax){max=1}min=+min||0;if(noMax){max=min;min=0}else{max=+max||0}if(floating||min%1||max%1){var rand=nativeRandom();return nativeMin(min+rand*(max-min+parseFloat("1e-"+((rand+"").length-1))),max)}return baseRandom(min,max)}function result(object,key){if(object){var value=object[key];return isFunction(value)?object[key]():value}}function template(text,data,options){var settings=lodash.templateSettings;text=String(text||"");options=defaults({},options,settings);var imports=defaults({},options.imports,settings.imports),importsKeys=keys(imports),importsValues=values(imports);var isEvaluating,index=0,interpolate=options.interpolate||reNoMatch,source="__p += '";var reDelimiters=RegExp((options.escape||reNoMatch).source+"|"+interpolate.source+"|"+(interpolate===reInterpolate?reEsTemplate:reNoMatch).source+"|"+(options.evaluate||reNoMatch).source+"|$","g");text.replace(reDelimiters,function(match,escapeValue,interpolateValue,esTemplateValue,evaluateValue,offset){interpolateValue||(interpolateValue=esTemplateValue);source+=text.slice(index,offset).replace(reUnescapedString,escapeStringChar);if(escapeValue){source+="' +\n__e("+escapeValue+") +\n'"}if(evaluateValue){isEvaluating=true;source+="';\n"+evaluateValue+";\n__p += '"}if(interpolateValue){source+="' +\n((__t = ("+interpolateValue+")) == null ? 
'' : __t) +\n'"}index=offset+match.length;return match});source+="';\n";var variable=options.variable,hasVariable=variable;if(!hasVariable){variable="obj";source="with ("+variable+") {\n"+source+"\n}\n"}source=(isEvaluating?source.replace(reEmptyStringLeading,""):source).replace(reEmptyStringMiddle,"$1").replace(reEmptyStringTrailing,"$1;");source="function("+variable+") {\n"+(hasVariable?"":variable+" || ("+variable+" = {});\n")+"var __t, __p = '', __e = _.escape"+(isEvaluating?", __j = Array.prototype.join;\n"+"function print() { __p += __j.call(arguments, '') }\n":";\n")+source+"return __p\n}";var sourceURL="\n/*\n//# sourceURL="+(options.sourceURL||"/lodash/template/source["+templateCounter++ +"]")+"\n*/";try{var result=Function(importsKeys,"return "+source+sourceURL).apply(undefined,importsValues)}catch(e){e.source=source;throw e}if(data){return result(data)}result.source=source;return result}function times(n,callback,thisArg){n=(n=+n)>-1?n:0;var index=-1,result=Array(n);callback=baseCreateCallback(callback,thisArg,1);while(++index Date: Fri, 29 Jan 2016 18:03:04 -0800 Subject: [PATCH 069/131] [SPARK-13071] Coalescing HadoopRDD overwrites existing input metrics This issue is causing tests to fail consistently in master with Hadoop 2.6 / 2.7. This is because for Hadoop 2.5+ we overwrite existing values of `InputMetrics#bytesRead` in each call to `HadoopRDD#compute`. In the case of coalesce, e.g. ``` sc.textFile(..., 4).coalesce(2).count() ``` we will call `compute` multiple times in the same task, overwriting `bytesRead` values from previous calls to `compute`. For a regression test, see `InputOutputMetricsSuite.input metrics for old hadoop with coalesce`. I did not add a new regression test because it's impossible without significant refactoring; there's a lot of existing duplicate code in this corner of Spark. This was caused by #10835. Author: Andrew Or Closes #10973 from andrewor14/fix-input-metrics-coalesce. --- core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala | 7 ++++++- .../src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala | 7 ++++++- .../spark/sql/execution/datasources/SqlNewHadoopRDD.scala | 7 ++++++- 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 3204e6adceca..e2ebd7f00d0d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -215,6 +215,7 @@ class HadoopRDD[K, V]( // TODO: there is a lot of duplicate code between this and NewHadoopRDD and SqlNewHadoopRDD val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop) + val existingBytesRead = inputMetrics.bytesRead // Sets the thread local variable for the file's name split.inputSplit.value match { @@ -230,9 +231,13 @@ class HadoopRDD[K, V]( case _ => None } + // For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics. + // If we do a coalesce, however, we are likely to compute multiple partitions in the same + // task and in the same thread, in which case we need to avoid override values written by + // previous partitions (SPARK-13071). 
def updateBytesRead(): Unit = { getBytesReadCallback.foreach { getBytesRead => - inputMetrics.setBytesRead(getBytesRead()) + inputMetrics.setBytesRead(existingBytesRead + getBytesRead()) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 4d2816e335fe..e71d3405c0ea 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -130,6 +130,7 @@ class NewHadoopRDD[K, V]( val conf = getConf val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop) + val existingBytesRead = inputMetrics.bytesRead // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes @@ -139,9 +140,13 @@ class NewHadoopRDD[K, V]( case _ => None } + // For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics. + // If we do a coalesce, however, we are likely to compute multiple partitions in the same + // task and in the same thread, in which case we need to avoid override values written by + // previous partitions (SPARK-13071). def updateBytesRead(): Unit = { getBytesReadCallback.foreach { getBytesRead => - inputMetrics.setBytesRead(getBytesRead()) + inputMetrics.setBytesRead(existingBytesRead + getBytesRead()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index edd87c2d8ed0..9703b16c86f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -127,6 +127,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( val conf = getConf(isDriverSide = false) val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop) + val existingBytesRead = inputMetrics.bytesRead // Sets the thread local variable for the file's name split.serializableHadoopSplit.value match { @@ -142,9 +143,13 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( case _ => None } + // For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics. + // If we do a coalesce, however, we are likely to compute multiple partitions in the same + // task and in the same thread, in which case we need to avoid override values written by + // previous partitions (SPARK-13071). def updateBytesRead(): Unit = { getBytesReadCallback.foreach { getBytesRead => - inputMetrics.setBytesRead(getBytesRead()) + inputMetrics.setBytesRead(existingBytesRead + getBytesRead()) } } From e6a02c66d53f59ba2d5c1548494ae80a385f9f5c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 29 Jan 2016 20:16:11 -0800 Subject: [PATCH 070/131] [SPARK-12914] [SQL] generate aggregation with grouping keys This PR add support for grouping keys for generated TungstenAggregate. Spilling and performance improvements for BytesToBytesMap will be done by followup PR. Author: Davies Liu Closes #10855 from davies/gen_keys. 
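As a rough sketch of the shape of query this patch targets — a hash aggregation keyed by an expression, compiled through the whole-stage-generated TungstenAggregate — the snippet below groups a generated range and sums per key. The column names `k` and `v`, the `sqlContext` handle, and the `spark.sql.codegen.wholeStage` setting are illustrative assumptions, not part of this patch:

```
// Sketch only: a grouped aggregation of the kind the generated
// TungstenAggregate path now handles (grouping keys + declarative aggregates).
// `sqlContext` and the config key below are assumed, not taken from this diff.
sqlContext.setConf("spark.sql.codegen.wholeStage", "true")

val df = sqlContext.range(1L << 22)
  .selectExpr("id % 512 AS k", "id AS v")   // hypothetical key/value columns

// groupBy on an expression-derived key followed by sum() exercises the
// keyed hash-aggregation codegen added here.
df.groupBy("k").sum("v").collect()
```
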
--- .../expressions/codegen/CodeGenerator.scala | 47 ++++ .../codegen/GenerateMutableProjection.scala | 27 +- .../sql/execution/BufferedRowIterator.java | 6 +- .../aggregate/TungstenAggregate.scala | 238 ++++++++++++++++-- .../BenchmarkWholeStageCodegen.scala | 119 ++++++++- .../execution/WholeStageCodegenSuite.scala | 9 + 6 files changed, 393 insertions(+), 53 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index e6704cf8bb1f..21f9198073d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -55,6 +55,20 @@ class CodegenContext { */ val references: mutable.ArrayBuffer[Any] = new mutable.ArrayBuffer[Any]() + /** + * Add an object to `references`, create a class member to access it. + * + * Returns the name of class member. + */ + def addReferenceObj(name: String, obj: Any, className: String = null): String = { + val term = freshName(name) + val idx = references.length + references += obj + val clsName = Option(className).getOrElse(obj.getClass.getName) + addMutableState(clsName, term, s"this.$term = ($clsName) references[$idx];") + term + } + /** * Holding a list of generated columns as input of current operator, will be used by * BoundReference to generate code. @@ -198,6 +212,39 @@ class CodegenContext { } } + /** + * Update a column in MutableRow from ExprCode. + */ + def updateColumn( + row: String, + dataType: DataType, + ordinal: Int, + ev: ExprCode, + nullable: Boolean): String = { + if (nullable) { + // Can't call setNullAt on DecimalType, because we need to keep the offset + if (dataType.isInstanceOf[DecimalType]) { + s""" + if (!${ev.isNull}) { + ${setColumn(row, dataType, ordinal, ev.value)}; + } else { + ${setColumn(row, dataType, ordinal, "null")}; + } + """ + } else { + s""" + if (!${ev.isNull}) { + ${setColumn(row, dataType, ordinal, ev.value)}; + } else { + $row.setNullAt($ordinal); + } + """ + } + } else { + s"""${setColumn(row, dataType, ordinal, ev.value)};""" + } + } + /** * Returns the name used in accessor and setter for a Java primitive type. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index ec31db19b94b..5b4dc8df8622 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -88,31 +88,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu val updates = validExpr.zip(index).map { case (e, i) => - if (e.nullable) { - if (e.dataType.isInstanceOf[DecimalType]) { - // Can't call setNullAt on DecimalType, because we need to keep the offset - s""" - if (this.isNull_$i) { - ${ctx.setColumn("mutableRow", e.dataType, i, "null")}; - } else { - ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; - } - """ - } else { - s""" - if (this.isNull_$i) { - mutableRow.setNullAt($i); - } else { - ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; - } - """ - } - } else { - s""" - ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; - """ - } - + val ev = ExprCode("", s"this.isNull_$i", s"this.value_$i") + ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java index b1bbb1da10a3..6acf70dbbad0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution; +import java.io.IOException; + import scala.collection.Iterator; import org.apache.spark.sql.catalyst.InternalRow; @@ -34,7 +36,7 @@ public class BufferedRowIterator { // used when there is no column in output protected UnsafeRow unsafeRow = new UnsafeRow(0); - public boolean hasNext() { + public boolean hasNext() throws IOException { if (currentRow == null) { processNext(); } @@ -56,7 +58,7 @@ public void setInput(Iterator iter) { * * After it's called, if currentRow is still null, it means no more rows left. 
*/ - protected void processNext() { + protected void processNext() throws IOException { if (input.hasNext()) { currentRow = input.next(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index ff2f38bfd910..57db7262fdaf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -17,16 +17,18 @@ package org.apache.spark.sql.execution.aggregate +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DecimalType, StructType} +import org.apache.spark.unsafe.KVIterator case class TungstenAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], @@ -114,22 +116,38 @@ case class TungstenAggregate( } } + // all the mode of aggregate expressions + private val modes = aggregateExpressions.map(_.mode).distinct + override def supportCodegen: Boolean = { - groupingExpressions.isEmpty && - // ImperativeAggregate is not supported right now - !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) + // ImperativeAggregate is not supported right now + !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) } - // The variables used as aggregation buffer - private var bufVars: Seq[ExprCode] = _ - - private val modes = aggregateExpressions.map(_.mode).distinct - override def upstream(): RDD[InternalRow] = { child.asInstanceOf[CodegenSupport].upstream() } protected override def doProduce(ctx: CodegenContext): String = { + if (groupingExpressions.isEmpty) { + doProduceWithoutKeys(ctx) + } else { + doProduceWithKeys(ctx) + } + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + if (groupingExpressions.isEmpty) { + doConsumeWithoutKeys(ctx, input) + } else { + doConsumeWithKeys(ctx, input) + } + } + + // The variables used as aggregation buffer + private var bufVars: Seq[ExprCode] = _ + + private def doProduceWithoutKeys(ctx: CodegenContext): String = { val initAgg = ctx.freshName("initAgg") ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") @@ -176,10 +194,10 @@ case class TungstenAggregate( (resultVars, resultVars.map(_.code).mkString("\n")) } - val doAgg = ctx.freshName("doAgg") + val doAgg = ctx.freshName("doAggregateWithoutKey") ctx.addNewFunction(doAgg, s""" - | private void $doAgg() { + | private void $doAgg() throws java.io.IOException { | // initialize aggregation buffer | ${bufVars.map(_.code).mkString("\n")} | @@ -200,7 +218,7 @@ case class TungstenAggregate( """.stripMargin } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { // only 
have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output @@ -212,7 +230,6 @@ case class TungstenAggregate( e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions } } - ctx.currentVars = bufVars ++ input // TODO: support subexpression elimination val aggVals = updateExpr.map(BindReferences.bindReference(_, inputAttrs).gen(ctx)) @@ -232,6 +249,199 @@ case class TungstenAggregate( """.stripMargin } + private val groupingAttributes = groupingExpressions.map(_.toAttribute) + private val groupingKeySchema = StructType.fromAttributes(groupingAttributes) + private val declFunctions = aggregateExpressions.map(_.aggregateFunction) + .filter(_.isInstanceOf[DeclarativeAggregate]) + .map(_.asInstanceOf[DeclarativeAggregate]) + private val bufferAttributes = declFunctions.flatMap(_.aggBufferAttributes) + private val bufferSchema = StructType.fromAttributes(bufferAttributes) + + // The name for HashMap + private var hashMapTerm: String = _ + + /** + * This is called by generated Java class, should be public. + */ + def createHashMap(): UnsafeFixedWidthAggregationMap = { + // create initialized aggregate buffer + val initExpr = declFunctions.flatMap(f => f.initialValues) + val initialBuffer = UnsafeProjection.create(initExpr)(EmptyRow) + + // create hashMap + new UnsafeFixedWidthAggregationMap( + initialBuffer, + bufferSchema, + groupingKeySchema, + TaskContext.get().taskMemoryManager(), + 1024 * 16, // initial capacity + TaskContext.get().taskMemoryManager().pageSizeBytes, + false // disable tracking of performance metrics + ) + } + + /** + * This is called by generated Java class, should be public. + */ + def createUnsafeJoiner(): UnsafeRowJoiner = { + GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) + } + + + /** + * Update peak execution memory, called in generated Java class. 
+ */ + def updatePeakMemory(hashMap: UnsafeFixedWidthAggregationMap): Unit = { + val mapMemory = hashMap.getPeakMemoryUsedBytes + val metrics = TaskContext.get().taskMetrics() + metrics.incPeakExecutionMemory(mapMemory) + } + + private def doProduceWithKeys(ctx: CodegenContext): String = { + val initAgg = ctx.freshName("initAgg") + ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + + // create hashMap + val thisPlan = ctx.addReferenceObj("plan", this) + hashMapTerm = ctx.freshName("hashMap") + val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName + ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();") + + // Create a name for iterator from HashMap + val iterTerm = ctx.freshName("mapIter") + ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "") + + // generate code for output + val keyTerm = ctx.freshName("aggKey") + val bufferTerm = ctx.freshName("aggBuffer") + val outputCode = if (modes.contains(Final) || modes.contains(Complete)) { + // generate output using resultExpressions + ctx.currentVars = null + ctx.INPUT_ROW = keyTerm + val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) => + BoundReference(i, e.dataType, e.nullable).gen(ctx) + } + ctx.INPUT_ROW = bufferTerm + val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) => + BoundReference(i, e.dataType, e.nullable).gen(ctx) + } + // evaluate the aggregation result + ctx.currentVars = bufferVars + val aggResults = declFunctions.map(_.evaluateExpression).map { e => + BindReferences.bindReference(e, bufferAttributes).gen(ctx) + } + // generate the final result + ctx.currentVars = keyVars ++ aggResults + val inputAttrs = groupingAttributes ++ aggregateAttributes + val resultVars = resultExpressions.map { e => + BindReferences.bindReference(e, inputAttrs).gen(ctx) + } + s""" + ${keyVars.map(_.code).mkString("\n")} + ${bufferVars.map(_.code).mkString("\n")} + ${aggResults.map(_.code).mkString("\n")} + ${resultVars.map(_.code).mkString("\n")} + + ${consume(ctx, resultVars)} + """ + + } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { + // This should be the last operator in a stage, we should output UnsafeRow directly + val joinerTerm = ctx.freshName("unsafeRowJoiner") + ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm, + s"$joinerTerm = $thisPlan.createUnsafeJoiner();") + val resultRow = ctx.freshName("resultRow") + s""" + UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm); + ${consume(ctx, null, resultRow)} + """ + + } else { + // generate result based on grouping key + ctx.INPUT_ROW = keyTerm + ctx.currentVars = null + val eval = resultExpressions.map{ e => + BindReferences.bindReference(e, groupingAttributes).gen(ctx) + } + s""" + ${eval.map(_.code).mkString("\n")} + ${consume(ctx, eval)} + """ + } + + val doAgg = ctx.freshName("doAggregateWithKeys") + ctx.addNewFunction(doAgg, + s""" + private void $doAgg() throws java.io.IOException { + ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} + + $iterTerm = $hashMapTerm.iterator(); + } + """) + + s""" + if (!$initAgg) { + $initAgg = true; + $doAgg(); + } + + // output the result + while ($iterTerm.next()) { + UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey(); + UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue(); + $outputCode + } + + $thisPlan.updatePeakMemory($hashMapTerm); + $hashMapTerm.free(); + """ + } + + private def doConsumeWithKeys( ctx: CodegenContext, input: Seq[ExprCode]): String = { + + // 
create grouping key + ctx.currentVars = input + val keyCode = GenerateUnsafeProjection.createCode( + ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output))) + val key = keyCode.value + val buffer = ctx.freshName("aggBuffer") + + // only have DeclarativeAggregate + val updateExpr = aggregateExpressions.flatMap { e => + e.mode match { + case Partial | Complete => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions + case PartialMerge | Final => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions + } + } + + val inputAttr = bufferAttributes ++ child.output + ctx.currentVars = new Array[ExprCode](bufferAttributes.length) ++ input + ctx.INPUT_ROW = buffer + // TODO: support subexpression elimination + val evals = updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx)) + val updates = evals.zipWithIndex.map { case (ev, i) => + val dt = updateExpr(i).dataType + ctx.updateColumn(buffer, dt, i, ev, updateExpr(i).nullable) + } + + s""" + // generate grouping key + ${keyCode.code} + UnsafeRow $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key); + if ($buffer == null) { + // failed to allocate the first page + throw new OutOfMemoryError("No enough memory for aggregation"); + } + + // evaluate aggregate function + ${evals.map(_.code).mkString("\n")} + // update aggregate buffer + ${updates.mkString("\n")} + """ + } + override def simpleString: String = { val allAggregateExpressions = aggregateExpressions diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index c4aad398bfa5..2f09c8a114bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -18,7 +18,12 @@ package org.apache.spark.sql.execution import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.hash.Murmur3_x86_32 +import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.util.Benchmark /** @@ -27,34 +32,124 @@ import org.apache.spark.util.Benchmark * build/sbt "sql/test-only *BenchmarkWholeStageCodegen" */ class BenchmarkWholeStageCodegen extends SparkFunSuite { - def testWholeStage(values: Int): Unit = { - val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark") - val sc = SparkContext.getOrCreate(conf) - val sqlContext = SQLContext.getOrCreate(sc) + lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark") + lazy val sc = SparkContext.getOrCreate(conf) + lazy val sqlContext = SQLContext.getOrCreate(sc) - val benchmark = new Benchmark("Single Int Column Scan", values) + def testWholeStage(values: Int): Unit = { + val benchmark = new Benchmark("rang/filter/aggregate", values) - benchmark.addCase("Without whole stage codegen") { iter => + benchmark.addCase("Without codegen") { iter => sqlContext.setConf("spark.sql.codegen.wholeStage", "false") sqlContext.range(values).filter("(id & 1) = 1").count() } - benchmark.addCase("With whole stage codegen") { iter => + benchmark.addCase("With codegen") { iter => sqlContext.setConf("spark.sql.codegen.wholeStage", "true") 
sqlContext.range(values).filter("(id & 1) = 1").count() } /* Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate + rang/filter/aggregate: Avg Time(ms) Avg Rate(M/s) Relative Rate ------------------------------------------------------------------------------- - Without whole stage codegen 7775.53 26.97 1.00 X - With whole stage codegen 342.15 612.94 22.73 X + Without codegen 7775.53 26.97 1.00 X + With codegen 342.15 612.94 22.73 X */ benchmark.run() } - ignore("benchmark") { - testWholeStage(1024 * 1024 * 200) + def testAggregateWithKey(values: Int): Unit = { + val benchmark = new Benchmark("Aggregate with keys", values) + + benchmark.addCase("Aggregate w/o codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "false") + sqlContext.range(values).selectExpr("(id & 65535) as k").groupBy("k").sum().collect() + } + benchmark.addCase(s"Aggregate w codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "true") + sqlContext.range(values).selectExpr("(id & 65535) as k").groupBy("k").sum().collect() + } + + /* + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Aggregate with keys: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + Aggregate w/o codegen 4254.38 4.93 1.00 X + Aggregate w codegen 2661.45 7.88 1.60 X + */ + benchmark.run() + } + + def testBytesToBytesMap(values: Int): Unit = { + val benchmark = new Benchmark("BytesToBytesMap", values) + + benchmark.addCase("hash") { iter => + var i = 0 + val keyBytes = new Array[Byte](16) + val valueBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + val value = new UnsafeRow(2) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var s = 0 + while (i < values) { + key.setInt(0, i % 1000) + val h = Murmur3_x86_32.hashUnsafeWords( + key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 0) + s += h + i += 1 + } + } + + Seq("off", "on").foreach { heap => + benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter => + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", s"${heap == "off"}") + .set("spark.memory.offHeap.size", "102400000"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val map = new BytesToBytesMap(taskMemoryManager, 1024, 64L<<20) + val keyBytes = new Array[Byte](16) + val valueBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + val value = new UnsafeRow(2) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var i = 0 + while (i < values) { + key.setInt(0, i % 65536) + val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes) + if (loc.isDefined) { + value.pointTo(loc.getValueAddress.getBaseObject, loc.getValueAddress.getBaseOffset, + loc.getValueLength) + value.setInt(0, value.getInt(0) + 1) + i += 1 + } else { + loc.putNewKey(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, + value.getBaseObject, value.getBaseOffset, value.getSizeInBytes) + } + } + } + } + + /** + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Aggregate with keys: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + hash 662.06 79.19 1.00 X + BytesToBytesMap (off Heap) 2209.42 23.73 0.30 X + BytesToBytesMap (on Heap) 2957.68 17.73 0.22 X + */ + benchmark.run() + } + + test("benchmark") { + // 
testWholeStage(1024 * 1024 * 200) + // testAggregateWithKey(20 << 20) + // testBytesToBytesMap(1024 * 1024 * 50) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 300788c88ab2..c2516509dfbb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -47,4 +47,13 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined) assert(df.collect() === Array(Row(9, 4.5))) } + + test("Aggregate with grouping keys should be included in WholeStageCodegen") { + val df = sqlContext.range(3).groupBy("id").count().orderBy("id") + val plan = df.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined) + assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1))) + } } From dab246f7e4664d36073ec49d9df8a11c5e998cdb Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 29 Jan 2016 23:37:51 -0800 Subject: [PATCH 071/131] [SPARK-13098] [SQL] remove GenericInternalRowWithSchema This class is only used for serialization of Python DataFrame. However, we don't require internal row there, so `GenericRowWithSchema` can also do the job. Author: Wenchen Fan Closes #10992 from cloud-fan/python. --- .../spark/sql/catalyst/expressions/rows.scala | 12 ------------ .../org/apache/spark/sql/execution/python.scala | 13 +++++-------- 2 files changed, 5 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 387d979484f2..be6b2530ef39 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -233,18 +233,6 @@ class GenericInternalRow(private[sql] val values: Array[Any]) extends BaseGeneri override def copy(): GenericInternalRow = this } -/** - * This is used for serialization of Python DataFrame - */ -class GenericInternalRowWithSchema(values: Array[Any], val schema: StructType) - extends GenericInternalRow(values) { - - /** No-arg constructor for serialization. */ - protected def this() = this(null, null) - - def fieldIndex(name: String): Int = schema.fieldIndex(name) -} - class GenericMutableRow(values: Array[Any]) extends MutableRow with BaseGenericInternalRow { /** No-arg constructor for serialization. 
*/ protected def this() = this(null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala index e3a016e18db8..bf62bb05c3d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala @@ -143,7 +143,7 @@ object EvaluatePython { values(i) = toJava(row.get(i, struct.fields(i).dataType), struct.fields(i).dataType) i += 1 } - new GenericInternalRowWithSchema(values, struct) + new GenericRowWithSchema(values, struct) case (a: ArrayData, array: ArrayType) => val values = new java.util.ArrayList[Any](a.numElements()) @@ -199,10 +199,7 @@ object EvaluatePython { case (c: Long, TimestampType) => c - case (c: String, StringType) => UTF8String.fromString(c) - case (c, StringType) => - // If we get here, c is not a string. Call toString on it. - UTF8String.fromString(c.toString) + case (c, StringType) => UTF8String.fromString(c.toString) case (c: String, BinaryType) => c.getBytes("utf-8") case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c @@ -263,11 +260,11 @@ object EvaluatePython { } /** - * Pickler for InternalRow + * Pickler for external row. */ private class RowPickler extends IObjectPickler { - private val cls = classOf[GenericInternalRowWithSchema] + private val cls = classOf[GenericRowWithSchema] // register this to Pickler and Unpickler def register(): Unit = { @@ -282,7 +279,7 @@ object EvaluatePython { } else { // it will be memorized by Pickler to save some bytes pickler.save(this) - val row = obj.asInstanceOf[GenericInternalRowWithSchema] + val row = obj.asInstanceOf[GenericRowWithSchema] // schema should always be same object for memoization pickler.save(row.schema) out.write(Opcodes.TUPLE1) From 289373b28cd2546165187de2e6a9185a1257b1e7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 30 Jan 2016 00:20:28 -0800 Subject: [PATCH 072/131] [SPARK-6363][BUILD] Make Scala 2.11 the default Scala version This patch changes Spark's build to make Scala 2.11 the default Scala version. To be clear, this does not mean that Spark will stop supporting Scala 2.10: users will still be able to compile Spark for Scala 2.10 by following the instructions on the "Building Spark" page; however, it does mean that Scala 2.11 will be the default Scala version used by our CI builds (including pull request builds). The Scala 2.11 compiler is faster than 2.10, so I think we'll be able to look forward to a slight speedup in our CI builds (it looks like it's about 2X faster for the Maven compile-only builds, for instance). After this patch is merged, I'll update Jenkins to add new compile-only jobs to ensure that Scala 2.10 compilation doesn't break. Author: Josh Rosen Closes #10608 from JoshRosen/SPARK-6363. 
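Before the file-by-file diff, one detail worth calling out: on the sbt side the switch reduces to normalizing the `-Dscala-2.10` flag, since a bare `-Dname` reaches sbt as an empty-string system property. A condensed sketch of the `SparkBuild.scala` hunk further below (the method name here is illustrative, not part of the patch):

```
// Condensed sketch of the SparkBuild.scala change: treat a bare -Dscala-2.10
// (empty-string system property) as "true", the same way Maven expands -Dname.
def normalizeScala210Flag(): Unit = {
  if (System.getProperty("scala-2.10") == "") {
    System.setProperty("scala-2.10", "true")
  }
}
```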
--- assembly/pom.xml | 4 +-- common/sketch/pom.xml | 4 +-- core/pom.xml | 4 +-- dev/create-release/release-build.sh | 14 ++++----- dev/deps/spark-deps-hadoop-2.2 | 31 +++++++++---------- dev/deps/spark-deps-hadoop-2.3 | 31 +++++++++---------- dev/deps/spark-deps-hadoop-2.4 | 31 +++++++++---------- dev/deps/spark-deps-hadoop-2.6 | 31 +++++++++---------- dev/deps/spark-deps-hadoop-2.7 | 31 +++++++++---------- docker-integration-tests/pom.xml | 4 +-- docs/_plugins/copy_api_dirs.rb | 2 +- docs/building-spark.md | 10 +++--- examples/pom.xml | 4 +-- external/akka/pom.xml | 4 +-- external/flume-assembly/pom.xml | 4 +-- external/flume-sink/pom.xml | 4 +-- external/flume/pom.xml | 4 +-- external/kafka-assembly/pom.xml | 4 +-- external/kafka/pom.xml | 4 +-- external/mqtt-assembly/pom.xml | 4 +-- external/mqtt/pom.xml | 4 +-- external/twitter/pom.xml | 4 +-- external/zeromq/pom.xml | 4 +-- extras/java8-tests/pom.xml | 4 +-- extras/kinesis-asl-assembly/pom.xml | 4 +-- extras/kinesis-asl/pom.xml | 4 +-- extras/spark-ganglia-lgpl/pom.xml | 4 +-- graphx/pom.xml | 4 +-- launcher/pom.xml | 4 +-- mllib/pom.xml | 4 +-- network/common/pom.xml | 4 +-- network/shuffle/pom.xml | 4 +-- network/yarn/pom.xml | 4 +-- pom.xml | 8 ++--- project/MimaBuild.scala | 2 +- project/MimaExcludes.scala | 6 ++++ project/SparkBuild.scala | 12 +++---- repl/pom.xml | 8 ++--- .../scala/org/apache/spark/repl/Main.scala | 9 +++++- .../org/apache/spark/repl/ReplSuite.scala | 7 +---- sql/catalyst/pom.xml | 13 ++------ sql/core/pom.xml | 6 ++-- sql/hive-thriftserver/pom.xml | 4 +-- sql/hive/pom.xml | 4 +-- streaming/pom.xml | 4 +-- tags/pom.xml | 4 +-- tools/pom.xml | 4 +-- unsafe/pom.xml | 4 +-- yarn/pom.xml | 4 +-- 49 files changed, 186 insertions(+), 194 deletions(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index 6c79f9189787..477d4931c3a8 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-assembly_2.10 + spark-assembly_2.11 Spark Project Assembly http://spark.apache.org/ pom diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index 2cafe8c548f5..442043cb5116 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-sketch_2.10 + spark-sketch_2.11 jar Spark Project Sketch http://spark.apache.org/ diff --git a/core/pom.xml b/core/pom.xml index 0ab170e028ab..be40d9936afd 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-core_2.10 + spark-core_2.11 core diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 00bf81120df6..2fd7fcc39ea2 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -134,9 +134,9 @@ if [[ "$1" == "package" ]]; then cd spark-$SPARK_VERSION-bin-$NAME - # TODO There should probably be a flag to make-distribution to allow 2.11 support - if [[ $FLAGS == *scala-2.11* ]]; then - ./dev/change-scala-version.sh 2.11 + # TODO There should probably be a flag to make-distribution to allow 2.10 support + if [[ $FLAGS == *scala-2.10* ]]; then + ./dev/change-scala-version.sh 2.10 fi export ZINC_PORT=$ZINC_PORT @@ -228,8 +228,8 @@ if [[ "$1" == "publish-snapshot" ]]; then $MVN -DzincPort=$ZINC_PORT --settings $tmp_settings -DskipTests 
$PUBLISH_PROFILES \ -Phive-thriftserver deploy - ./dev/change-scala-version.sh 2.11 - $MVN -DzincPort=$ZINC_PORT -Dscala-2.11 --settings $tmp_settings \ + ./dev/change-scala-version.sh 2.10 + $MVN -DzincPort=$ZINC_PORT -Dscala-2.10 --settings $tmp_settings \ -DskipTests $PUBLISH_PROFILES clean deploy # Clean-up Zinc nailgun process @@ -266,9 +266,9 @@ if [[ "$1" == "publish-release" ]]; then $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -DskipTests $PUBLISH_PROFILES \ -Phive-thriftserver clean install - ./dev/change-scala-version.sh 2.11 + ./dev/change-scala-version.sh 2.10 - $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -Dscala-2.11 \ + $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -Dscala-2.10 \ -DskipTests $PUBLISH_PROFILES clean install # Clean-up Zinc nailgun process diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 4d9937c5cbc3..3a14499d9b4d 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -14,13 +14,13 @@ avro-ipc-1.7.7-tests.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar bonecp-0.8.0.RELEASE.jar -breeze-macros_2.10-0.11.2.jar -breeze_2.10-0.11.2.jar +breeze-macros_2.11-0.11.2.jar +breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar chill-java-0.5.0.jar -chill_2.10-0.5.0.jar +chill_2.11-0.5.0.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -86,10 +86,9 @@ jackson-core-asl-1.9.13.jar jackson-databind-2.5.3.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.10-2.5.3.jar +jackson-module-scala_2.11-2.5.3.jar jackson-xc-1.9.13.jar janino-2.7.8.jar -jansi-1.4.jar javax.inject-1.jar javax.servlet-3.0.0.v201112011016.jar javax.servlet-3.1.jar @@ -111,15 +110,14 @@ jets3t-0.7.1.jar jettison-1.1.jar jetty-all-7.6.0.v20120127.jar jetty-util-6.1.26.jar -jline-2.10.5.jar jline-2.12.jar joda-time-2.9.jar jodd-core-3.5.2.jar jpam-1.1.jar json-20090211.jar -json4s-ast_2.10-3.2.10.jar -json4s-core_2.10-3.2.10.jar -json4s-jackson_2.10-3.2.10.jar +json4s-ast_2.11-3.2.10.jar +json4s-core_2.11-3.2.10.jar +json4s-jackson_2.11-3.2.10.jar jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar @@ -158,19 +156,20 @@ pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar py4j-0.9.1.jar pyrolite-4.9.jar -quasiquotes_2.10-2.0.0-M8.jar reflectasm-1.07-shaded.jar -scala-compiler-2.10.5.jar -scala-library-2.10.5.jar -scala-reflect-2.10.5.jar -scalap-2.10.5.jar +scala-compiler-2.11.7.jar +scala-library-2.11.7.jar +scala-parser-combinators_2.11-1.0.4.jar +scala-reflect-2.11.7.jar +scala-xml_2.11-1.0.2.jar +scalap-2.11.7.jar servlet-api-2.5.jar slf4j-api-1.7.10.jar slf4j-log4j12-1.7.10.jar snappy-0.2.jar snappy-java-1.1.2.jar -spire-macros_2.10-0.7.4.jar -spire_2.10-0.7.4.jar +spire-macros_2.11-0.7.4.jar +spire_2.11-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index fd659ee20df1..615836b3d3b7 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -16,13 +16,13 @@ avro-mapred-1.7.7-hadoop2.jar base64-2.3.8.jar bcprov-jdk15on-1.51.jar bonecp-0.8.0.RELEASE.jar -breeze-macros_2.10-0.11.2.jar -breeze_2.10-0.11.2.jar +breeze-macros_2.11-0.11.2.jar +breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar chill-java-0.5.0.jar -chill_2.10-0.5.0.jar +chill_2.11-0.5.0.jar 
commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -81,10 +81,9 @@ jackson-core-asl-1.9.13.jar jackson-databind-2.5.3.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.10-2.5.3.jar +jackson-module-scala_2.11-2.5.3.jar jackson-xc-1.9.13.jar janino-2.7.8.jar -jansi-1.4.jar java-xmlbuilder-1.0.jar javax.inject-1.jar javax.servlet-3.0.0.v201112011016.jar @@ -102,15 +101,14 @@ jettison-1.1.jar jetty-6.1.26.jar jetty-all-7.6.0.v20120127.jar jetty-util-6.1.26.jar -jline-2.10.5.jar jline-2.12.jar joda-time-2.9.jar jodd-core-3.5.2.jar jpam-1.1.jar json-20090211.jar -json4s-ast_2.10-3.2.10.jar -json4s-core_2.10-3.2.10.jar -json4s-jackson_2.10-3.2.10.jar +json4s-ast_2.11-3.2.10.jar +json4s-core_2.11-3.2.10.jar +json4s-jackson_2.11-3.2.10.jar jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar @@ -149,19 +147,20 @@ pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar py4j-0.9.1.jar pyrolite-4.9.jar -quasiquotes_2.10-2.0.0-M8.jar reflectasm-1.07-shaded.jar -scala-compiler-2.10.5.jar -scala-library-2.10.5.jar -scala-reflect-2.10.5.jar -scalap-2.10.5.jar +scala-compiler-2.11.7.jar +scala-library-2.11.7.jar +scala-parser-combinators_2.11-1.0.4.jar +scala-reflect-2.11.7.jar +scala-xml_2.11-1.0.2.jar +scalap-2.11.7.jar servlet-api-2.5.jar slf4j-api-1.7.10.jar slf4j-log4j12-1.7.10.jar snappy-0.2.jar snappy-java-1.1.2.jar -spire-macros_2.10-0.7.4.jar -spire_2.10-0.7.4.jar +spire-macros_2.11-0.7.4.jar +spire_2.11-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index afae3deb9ada..f275226f1d08 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -16,13 +16,13 @@ avro-mapred-1.7.7-hadoop2.jar base64-2.3.8.jar bcprov-jdk15on-1.51.jar bonecp-0.8.0.RELEASE.jar -breeze-macros_2.10-0.11.2.jar -breeze_2.10-0.11.2.jar +breeze-macros_2.11-0.11.2.jar +breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar chill-java-0.5.0.jar -chill_2.10-0.5.0.jar +chill_2.11-0.5.0.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -81,10 +81,9 @@ jackson-core-asl-1.9.13.jar jackson-databind-2.5.3.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.10-2.5.3.jar +jackson-module-scala_2.11-2.5.3.jar jackson-xc-1.9.13.jar janino-2.7.8.jar -jansi-1.4.jar java-xmlbuilder-1.0.jar javax.inject-1.jar javax.servlet-3.0.0.v201112011016.jar @@ -103,15 +102,14 @@ jettison-1.1.jar jetty-6.1.26.jar jetty-all-7.6.0.v20120127.jar jetty-util-6.1.26.jar -jline-2.10.5.jar jline-2.12.jar joda-time-2.9.jar jodd-core-3.5.2.jar jpam-1.1.jar json-20090211.jar -json4s-ast_2.10-3.2.10.jar -json4s-core_2.10-3.2.10.jar -json4s-jackson_2.10-3.2.10.jar +json4s-ast_2.11-3.2.10.jar +json4s-core_2.11-3.2.10.jar +json4s-jackson_2.11-3.2.10.jar jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar @@ -150,19 +148,20 @@ pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar py4j-0.9.1.jar pyrolite-4.9.jar -quasiquotes_2.10-2.0.0-M8.jar reflectasm-1.07-shaded.jar -scala-compiler-2.10.5.jar -scala-library-2.10.5.jar -scala-reflect-2.10.5.jar -scalap-2.10.5.jar +scala-compiler-2.11.7.jar +scala-library-2.11.7.jar +scala-parser-combinators_2.11-1.0.4.jar +scala-reflect-2.11.7.jar +scala-xml_2.11-1.0.2.jar +scalap-2.11.7.jar servlet-api-2.5.jar slf4j-api-1.7.10.jar slf4j-log4j12-1.7.10.jar snappy-0.2.jar snappy-java-1.1.2.jar -spire-macros_2.10-0.7.4.jar 
-spire_2.10-0.7.4.jar +spire-macros_2.11-0.7.4.jar +spire_2.11-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 5a6460136a3a..21432a16e365 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -20,13 +20,13 @@ avro-mapred-1.7.7-hadoop2.jar base64-2.3.8.jar bcprov-jdk15on-1.51.jar bonecp-0.8.0.RELEASE.jar -breeze-macros_2.10-0.11.2.jar -breeze_2.10-0.11.2.jar +breeze-macros_2.11-0.11.2.jar +breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar chill-java-0.5.0.jar -chill_2.10-0.5.0.jar +chill_2.11-0.5.0.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -87,10 +87,9 @@ jackson-core-asl-1.9.13.jar jackson-databind-2.5.3.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.10-2.5.3.jar +jackson-module-scala_2.11-2.5.3.jar jackson-xc-1.9.13.jar janino-2.7.8.jar -jansi-1.4.jar java-xmlbuilder-1.0.jar javax.inject-1.jar javax.servlet-3.0.0.v201112011016.jar @@ -109,15 +108,14 @@ jettison-1.1.jar jetty-6.1.26.jar jetty-all-7.6.0.v20120127.jar jetty-util-6.1.26.jar -jline-2.10.5.jar jline-2.12.jar joda-time-2.9.jar jodd-core-3.5.2.jar jpam-1.1.jar json-20090211.jar -json4s-ast_2.10-3.2.10.jar -json4s-core_2.10-3.2.10.jar -json4s-jackson_2.10-3.2.10.jar +json4s-ast_2.11-3.2.10.jar +json4s-core_2.11-3.2.10.jar +json4s-jackson_2.11-3.2.10.jar jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar @@ -156,19 +154,20 @@ pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar py4j-0.9.1.jar pyrolite-4.9.jar -quasiquotes_2.10-2.0.0-M8.jar reflectasm-1.07-shaded.jar -scala-compiler-2.10.5.jar -scala-library-2.10.5.jar -scala-reflect-2.10.5.jar -scalap-2.10.5.jar +scala-compiler-2.11.7.jar +scala-library-2.11.7.jar +scala-parser-combinators_2.11-1.0.4.jar +scala-reflect-2.11.7.jar +scala-xml_2.11-1.0.2.jar +scalap-2.11.7.jar servlet-api-2.5.jar slf4j-api-1.7.10.jar slf4j-log4j12-1.7.10.jar snappy-0.2.jar snappy-java-1.1.2.jar -spire-macros_2.10-0.7.4.jar -spire_2.10-0.7.4.jar +spire-macros_2.11-0.7.4.jar +spire_2.11-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 70083e7f3d16..20e09cd00263 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -20,13 +20,13 @@ avro-mapred-1.7.7-hadoop2.jar base64-2.3.8.jar bcprov-jdk15on-1.51.jar bonecp-0.8.0.RELEASE.jar -breeze-macros_2.10-0.11.2.jar -breeze_2.10-0.11.2.jar +breeze-macros_2.11-0.11.2.jar +breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar chill-java-0.5.0.jar -chill_2.10-0.5.0.jar +chill_2.11-0.5.0.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -87,10 +87,9 @@ jackson-core-asl-1.9.13.jar jackson-databind-2.5.3.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.10-2.5.3.jar +jackson-module-scala_2.11-2.5.3.jar jackson-xc-1.9.13.jar janino-2.7.8.jar -jansi-1.4.jar java-xmlbuilder-1.0.jar javax.inject-1.jar javax.servlet-3.0.0.v201112011016.jar @@ -109,15 +108,14 @@ jettison-1.1.jar jetty-6.1.26.jar jetty-all-7.6.0.v20120127.jar jetty-util-6.1.26.jar -jline-2.10.5.jar jline-2.12.jar joda-time-2.9.jar jodd-core-3.5.2.jar jpam-1.1.jar json-20090211.jar -json4s-ast_2.10-3.2.10.jar -json4s-core_2.10-3.2.10.jar 
-json4s-jackson_2.10-3.2.10.jar +json4s-ast_2.11-3.2.10.jar +json4s-core_2.11-3.2.10.jar +json4s-jackson_2.11-3.2.10.jar jsp-api-2.1.jar jsr305-1.3.9.jar jta-1.1.jar @@ -157,19 +155,20 @@ pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar py4j-0.9.1.jar pyrolite-4.9.jar -quasiquotes_2.10-2.0.0-M8.jar reflectasm-1.07-shaded.jar -scala-compiler-2.10.5.jar -scala-library-2.10.5.jar -scala-reflect-2.10.5.jar -scalap-2.10.5.jar +scala-compiler-2.11.7.jar +scala-library-2.11.7.jar +scala-parser-combinators_2.11-1.0.4.jar +scala-reflect-2.11.7.jar +scala-xml_2.11-1.0.2.jar +scalap-2.11.7.jar servlet-api-2.5.jar slf4j-api-1.7.10.jar slf4j-log4j12-1.7.10.jar snappy-0.2.jar snappy-java-1.1.2.jar -spire-macros_2.10-0.7.4.jar -spire_2.10-0.7.4.jar +spire-macros_2.11-0.7.4.jar +spire_2.11-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar diff --git a/docker-integration-tests/pom.xml b/docker-integration-tests/pom.xml index 78b638ecfa63..833ca29cd821 100644 --- a/docker-integration-tests/pom.xml +++ b/docker-integration-tests/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml - spark-docker-integration-tests_2.10 + spark-docker-integration-tests_2.11 jar Spark Project Docker Integration Tests http://spark.apache.org/ diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 174c202e3791..f926d67e6bea 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -37,7 +37,7 @@ # Copy over the unified ScalaDoc for all projects to api/scala. # This directory will be copied over to _site when `jekyll` command is run. - source = "../target/scala-2.10/unidoc" + source = "../target/scala-2.11/unidoc" dest = "api/scala" puts "Making directory " + dest diff --git a/docs/building-spark.md b/docs/building-spark.md index e1abcf1be501..975e1b295c8a 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -114,13 +114,11 @@ By default Spark will build with Hive 0.13.1 bindings. mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-thriftserver -DskipTests clean package {% endhighlight %} -# Building for Scala 2.11 -To produce a Spark package compiled with Scala 2.11, use the `-Dscala-2.11` property: +# Building for Scala 2.10 +To produce a Spark package compiled with Scala 2.10, use the `-Dscala-2.10` property: - ./dev/change-scala-version.sh 2.11 - mvn -Pyarn -Phadoop-2.4 -Dscala-2.11 -DskipTests clean package - -Spark does not yet support its JDBC component for Scala 2.11. 
+ ./dev/change-scala-version.sh 2.10 + mvn -Pyarn -Phadoop-2.4 -Dscala-2.10 -DskipTests clean package # Spark Tests in Maven diff --git a/examples/pom.xml b/examples/pom.xml index 9437cee2abfd..82baa9085b4f 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-examples_2.10 + spark-examples_2.11 examples diff --git a/external/akka/pom.xml b/external/akka/pom.xml index 06c8e8aaabd8..bbe644e3b32b 100644 --- a/external/akka/pom.xml +++ b/external/akka/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-akka_2.10 + spark-streaming-akka_2.11 streaming-akka diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index b2c377fe4cc9..ac15b93c048d 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-flume-assembly_2.10 + spark-streaming-flume-assembly_2.11 jar Spark Project External Flume Assembly http://spark.apache.org/ diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 4b6485ee0a71..e4effe158c82 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-flume-sink_2.10 + spark-streaming-flume-sink_2.11 streaming-flume-sink diff --git a/external/flume/pom.xml b/external/flume/pom.xml index a79656c6f7d9..d650dd034d63 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-flume_2.10 + spark-streaming-flume_2.11 streaming-flume diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml index 0c466b3c4ac3..62818f5e8f43 100644 --- a/external/kafka-assembly/pom.xml +++ b/external/kafka-assembly/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-kafka-assembly_2.10 + spark-streaming-kafka-assembly_2.11 jar Spark Project External Kafka Assembly http://spark.apache.org/ diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 5180ab6dbafb..68d52e9339b3 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-kafka_2.10 + spark-streaming-kafka_2.11 streaming-kafka diff --git a/external/mqtt-assembly/pom.xml b/external/mqtt-assembly/pom.xml index c4a1ae26ea69..ac2a3f65ed2f 100644 --- a/external/mqtt-assembly/pom.xml +++ b/external/mqtt-assembly/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-mqtt-assembly_2.10 + spark-streaming-mqtt-assembly_2.11 jar Spark Project External MQTT Assembly http://spark.apache.org/ diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index d3a2bf5825b0..d0d968782c7f 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 
2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-mqtt_2.10 + spark-streaming-mqtt_2.11 streaming-mqtt diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 7b628b09ea6a..5d4053afcbba 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-twitter_2.10 + spark-streaming-twitter_2.11 streaming-twitter diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 7781aaeed9e0..f16bc0f31974 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-zeromq_2.10 + spark-streaming-zeromq_2.11 streaming-zeromq diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 4dfe3b654df1..0ad9c5303a36 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -19,13 +19,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - java8-tests_2.10 + java8-tests_2.11 pom Spark Project Java8 Tests POM diff --git a/extras/kinesis-asl-assembly/pom.xml b/extras/kinesis-asl-assembly/pom.xml index 601080c2e6fb..d1c38c7ca5d6 100644 --- a/extras/kinesis-asl-assembly/pom.xml +++ b/extras/kinesis-asl-assembly/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-kinesis-asl-assembly_2.10 + spark-streaming-kinesis-asl-assembly_2.11 jar Spark Project Kinesis Assembly http://spark.apache.org/ diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 20e2c5e0ffbe..935155eb5d36 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -19,14 +19,14 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-kinesis-asl_2.10 + spark-streaming-kinesis-asl_2.11 jar Spark Kinesis Integration diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index b046a10a04d5..bfb92791de3d 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -19,14 +19,14 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-ganglia-lgpl_2.10 + spark-ganglia-lgpl_2.11 jar Spark Ganglia Integration diff --git a/graphx/pom.xml b/graphx/pom.xml index 388a0ef06a2b..1813f383cdcb 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-graphx_2.10 + spark-graphx_2.11 graphx diff --git a/launcher/pom.xml b/launcher/pom.xml index 135866cea2e7..ef731948826e 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-launcher_2.10 + spark-launcher_2.11 jar Spark Project Launcher http://spark.apache.org/ diff --git a/mllib/pom.xml b/mllib/pom.xml index 42af2b8b3e41..816f3f683038 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-mllib_2.10 + spark-mllib_2.11 mllib diff --git a/network/common/pom.xml b/network/common/pom.xml 
index eda2b7307088..bd507c2cb6c4 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-network-common_2.10 + spark-network-common_2.11 jar Spark Project Networking http://spark.apache.org/ diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index f9aa7e2dd1f4..810ec10ca05b 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-network-shuffle_2.10 + spark-network-shuffle_2.11 jar Spark Project Shuffle Streaming Service http://spark.apache.org/ diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml index a19cbb04b18c..a28785b16e1e 100644 --- a/network/yarn/pom.xml +++ b/network/yarn/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-network-yarn_2.10 + spark-network-yarn_2.11 jar Spark Project YARN Shuffle Service http://spark.apache.org/ diff --git a/pom.xml b/pom.xml index fb7750602c42..d0387aca66d0 100644 --- a/pom.xml +++ b/pom.xml @@ -25,7 +25,7 @@ 14 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT pom Spark Project Parent POM @@ -165,7 +165,7 @@ 3.2.2 2.10.5 - 2.10 + 2.11 ${scala.version} org.scala-lang 1.9.13 @@ -2456,7 +2456,7 @@ scala-2.10 - !scala-2.11 + scala-2.10 2.10.5 @@ -2488,7 +2488,7 @@ scala-2.11 - scala-2.11 + !scala-2.10 2.11.7 diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 41856443af49..4adf64a5a0d8 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -95,7 +95,7 @@ object MimaBuild { // because spark-streaming-mqtt(1.6.0) depends on it. // Remove the setting on updating previousSparkVersion. 
val previousSparkVersion = "1.6.0" - val fullId = "spark-" + projectRef.project + "_2.10" + val fullId = "spark-" + projectRef.project + "_2.11" mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), binaryIssueFilters ++= ignoredABIProblems(sparkHome, version.value), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a3ae4d2b730f..3748e07f88aa 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -220,6 +220,12 @@ object MimaExcludes { // SPARK-11622 Make LibSVMRelation extends HadoopFsRelation and Add LibSVMOutputWriter ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.source.libsvm.DefaultSource"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.source.libsvm.DefaultSource.createRelation") + ) ++ Seq( + // SPARK-6363 Make Scala 2.11 the default Scala version + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.cleanup"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.metadataCleaner"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.cluster.YarnSchedulerBackend$YarnDriverEndpoint"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.cluster.YarnSchedulerBackend$YarnSchedulerEndpoint") ) case v if v.startsWith("1.6") => Seq( diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 4224a65a822b..550b5bad8a46 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -119,11 +119,11 @@ object SparkBuild extends PomBuild { v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.trim.replaceAll("-P", "")).toSeq } - if (System.getProperty("scala-2.11") == "") { - // To activate scala-2.11 profile, replace empty property value to non-empty value + if (System.getProperty("scala-2.10") == "") { + // To activate scala-2.10 profile, replace empty property value to non-empty value // in the same way as Maven which handles -Dname as -Dname=true before executes build process. 
// see: https://github.com/apache/maven/blob/maven-3.0.4/maven-embedder/src/main/java/org/apache/maven/cli/MavenCli.java#L1082 - System.setProperty("scala-2.11", "true") + System.setProperty("scala-2.10", "true") } profiles } @@ -382,7 +382,7 @@ object OldDeps { lazy val project = Project("oldDeps", file("dev"), settings = oldDepsSettings) def versionArtifact(id: String): Option[sbt.ModuleID] = { - val fullId = id + "_2.10" + val fullId = id + "_2.11" Some("org.apache.spark" % fullId % "1.2.0") } @@ -390,7 +390,7 @@ object OldDeps { name := "old-deps", scalaVersion := "2.10.5", libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq", - "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-twitter", + "spark-streaming-flume", "spark-streaming-twitter", "spark-streaming", "spark-mllib", "spark-graphx", "spark-core").map(versionArtifact(_).get intransitive()) ) @@ -704,7 +704,7 @@ object Java8TestSettings { lazy val settings = Seq( javacJVMVersion := "1.8", // Targeting Java 8 bytecode is only supported in Scala 2.11.4 and higher: - scalacJVMVersion := (if (System.getProperty("scala-2.11") == "true") "1.8" else "1.7") + scalacJVMVersion := (if (System.getProperty("scala-2.10") == "true") "1.7" else "1.8") ) } diff --git a/repl/pom.xml b/repl/pom.xml index efc3dd452e32..0f396c9b809b 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-repl_2.10 + spark-repl_2.11 jar Spark Project REPL http://spark.apache.org/ @@ -159,7 +159,7 @@ scala-2.10 - !scala-2.11 + scala-2.10 @@ -173,7 +173,7 @@ scala-2.11 - scala-2.11 + !scala-2.10 scala-2.11/src/main/scala diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index bb3081d12938..07ba28bb0754 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -33,7 +33,8 @@ object Main extends Logging { var sparkContext: SparkContext = _ var sqlContext: SQLContext = _ - var interp = new SparkILoop // this is a public var because tests reset it. + // this is a public var because tests reset it. 
+ var interp: SparkILoop = _ private var hasErrors = false @@ -43,6 +44,12 @@ object Main extends Logging { } def main(args: Array[String]) { + doMain(args, new SparkILoop) + } + + // Visible for testing + private[repl] def doMain(args: Array[String], _interp: SparkILoop): Unit = { + interp = _interp val interpArguments = List( "-Yrepl-class-based", "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 63f3688c9e61..b9ed79da421a 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -50,12 +50,7 @@ class ReplSuite extends SparkFunSuite { System.setProperty(CONF_EXECUTOR_CLASSPATH, classpath) System.setProperty("spark.master", master) - val interp = { - new SparkILoop(in, new PrintWriter(out)) - } - org.apache.spark.repl.Main.interp = interp - Main.main(Array("-classpath", classpath)) // call main - org.apache.spark.repl.Main.interp = null + Main.doMain(Array("-classpath", classpath), new SparkILoop(in, new PrintWriter(out))) if (oldExecutorClasspath != null) { System.setProperty(CONF_EXECUTOR_CLASSPATH, oldExecutorClasspath) diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 76ca3f3bb1bf..c2ad9b99f3ac 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-catalyst_2.10 + spark-catalyst_2.11 jar Spark Project Catalyst http://spark.apache.org/ @@ -127,13 +127,4 @@ - - - - scala-2.10 - - !scala-2.11 - - - diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 4bb55f6b7f73..89e01fc01596 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-sql_2.10 + spark-sql_2.11 jar Spark Project SQL http://spark.apache.org/ @@ -44,7 +44,7 @@ org.apache.spark - spark-sketch_2.10 + spark-sketch_2.11 ${project.version} diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 435e565f6345..c8d17bd46858 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-hive-thriftserver_2.10 + spark-hive-thriftserver_2.11 jar Spark Project Hive Thrift Server http://spark.apache.org/ diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index cd0c2aeb93a9..14cf9acf09d5 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-hive_2.10 + spark-hive_2.11 jar Spark Project Hive http://spark.apache.org/ diff --git a/streaming/pom.xml b/streaming/pom.xml index 39cbd0d00f95..7d409c5d3b07 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-streaming_2.10 + spark-streaming_2.11 streaming diff --git a/tags/pom.xml b/tags/pom.xml index 9e4610dae7a6..3e8e6f618287 100644 --- a/tags/pom.xml +++ b/tags/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-test-tags_2.10 
+ spark-test-tags_2.11 jar Spark Project Test Tags http://spark.apache.org/ diff --git a/tools/pom.xml b/tools/pom.xml index 30cbb6a5a59c..b3a5ae277124 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -19,13 +19,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-tools_2.10 + spark-tools_2.11 tools diff --git a/unsafe/pom.xml b/unsafe/pom.xml index 21fef3415adc..75fea556eeae 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-unsafe_2.10 + spark-unsafe_2.11 jar Spark Project Unsafe http://spark.apache.org/ diff --git a/yarn/pom.xml b/yarn/pom.xml index a8c122fd40a1..328bb6678db9 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -19,13 +19,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-yarn_2.10 + spark-yarn_2.11 jar Spark Project YARN From de283719980ae78b740e507e4d70c7ebbf6c5f74 Mon Sep 17 00:00:00 2001 From: wangyang Date: Sat, 30 Jan 2016 15:20:57 -0800 Subject: [PATCH 073/131] [SPARK-13100][SQL] improving the performance of stringToDate method in DateTimeUtils.scala In jdk1.7 TimeZone.getTimeZone() is synchronized, so use an instance variable to hold an GMT TimeZone object instead of instantiate it every time. Author: wangyang Closes #10994 from wangyang1992/datetimeUtil. --- .../org/apache/spark/sql/catalyst/util/DateTimeUtils.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index f18c052b68e3..a159bc6a6141 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -55,6 +55,7 @@ object DateTimeUtils { // this is year -17999, calculation: 50 * daysIn400Year final val YearZero = -17999 final val toYearZero = to2001 + 7304850 + final val TimeZoneGMT = TimeZone.getTimeZone("GMT") @transient lazy val defaultTimeZone = TimeZone.getDefault @@ -407,7 +408,7 @@ object DateTimeUtils { segments(2) < 1 || segments(2) > 31) { return None } - val c = Calendar.getInstance(TimeZone.getTimeZone("GMT")) + val c = Calendar.getInstance(TimeZoneGMT) c.set(segments(0), segments(1) - 1, segments(2), 0, 0, 0) c.set(Calendar.MILLISECOND, 0) Some((c.getTimeInMillis / MILLIS_PER_DAY).toInt) From a1303de0a0e9d0c80327977abf52a79e2aa95e1f Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 30 Jan 2016 23:02:49 -0800 Subject: [PATCH 074/131] [SPARK-13070][SQL] Better error message when Parquet schema merging fails Make sure we throw better error messages when Parquet schema merging fails. Author: Cheng Lian Author: Liang-Chi Hsieh Closes #10979 from viirya/schema-merging-failure-message. 
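For context, a minimal reproduction of the failure this patch targets. This is a sketch only: `sqlContext` and the scratch directory `path` are placeholder names, and the snippet simply mirrors the new test added to ParquetSchemaSuite further down.

```
import org.apache.spark.SparkException

// Two Parquet directories whose schemas disagree on the type of `id`:
// range() writes it as bigint, while the cast writes it as int.
sqlContext.range(3).write.parquet(s"$path/p=1")
sqlContext.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2")

try {
  // Schema merging cannot reconcile bigint with int, so resolving the schema throws.
  sqlContext.read.option("mergeSchema", "true").parquet(path).schema
} catch {
  case e: SparkException =>
    // After this patch the message names the offending file and prints its schema,
    // e.g. "Failed merging schema of file ...", instead of a bare type-merge error.
    println(e.getMessage)
}
```

Pointing at the concrete file matters because a merge failure across hundreds of part-files is otherwise hard to localize.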
--- .../apache/spark/sql/types/StructType.scala | 6 ++-- .../datasources/parquet/ParquetRelation.scala | 33 ++++++++++++++++--- .../parquet/ParquetFilterSuite.scala | 15 +++++++++ .../parquet/ParquetSchemaSuite.scala | 30 +++++++++++++++++ 4 files changed, 77 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index da0c92864e9b..c9e7e7fe633b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -424,13 +424,13 @@ object StructType extends AbstractDataType { if ((leftPrecision == rightPrecision) && (leftScale == rightScale)) { DecimalType(leftPrecision, leftScale) } else if ((leftPrecision != rightPrecision) && (leftScale != rightScale)) { - throw new SparkException("Failed to merge Decimal Tpes with incompatible " + + throw new SparkException("Failed to merge decimal types with incompatible " + s"precision $leftPrecision and $rightPrecision & scale $leftScale and $rightScale") } else if (leftPrecision != rightPrecision) { - throw new SparkException("Failed to merge Decimal Tpes with incompatible " + + throw new SparkException("Failed to merge decimal types with incompatible " + s"precision $leftPrecision and $rightPrecision") } else { - throw new SparkException("Failed to merge Decimal Tpes with incompatible " + + throw new SparkException("Failed to merge decimal types with incompatible " + s"scala $leftScale and $rightScale") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index f87590095d34..1e686d41f41d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -800,12 +800,37 @@ private[sql] object ParquetRelation extends Logging { assumeInt96IsTimestamp = assumeInt96IsTimestamp, writeLegacyParquetFormat = writeLegacyParquetFormat) - footers.map { footer => - ParquetRelation.readSchemaFromFooter(footer, converter) - }.reduceLeftOption(_ merge _).iterator + if (footers.isEmpty) { + Iterator.empty + } else { + var mergedSchema = ParquetRelation.readSchemaFromFooter(footers.head, converter) + footers.tail.foreach { footer => + val schema = ParquetRelation.readSchemaFromFooter(footer, converter) + try { + mergedSchema = mergedSchema.merge(schema) + } catch { case cause: SparkException => + throw new SparkException( + s"Failed merging schema of file ${footer.getFile}:\n${schema.treeString}", cause) + } + } + Iterator.single(mergedSchema) + } }.collect() - partiallyMergedSchemas.reduceLeftOption(_ merge _) + if (partiallyMergedSchemas.isEmpty) { + None + } else { + var finalSchema = partiallyMergedSchemas.head + partiallyMergedSchemas.tail.foreach { schema => + try { + finalSchema = finalSchema.merge(schema) + } catch { case cause: SparkException => + throw new SparkException( + s"Failed merging schema:\n${schema.treeString}", cause) + } + } + Some(finalSchema) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 1796b3af0e37..3ded32c45054 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -421,6 +421,21 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // We will remove the temporary metadata when writing Parquet file. val forPathSix = sqlContext.read.parquet(pathSix).schema assert(forPathSix.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField))) + + // sanity test: make sure optional metadata field is not wrongly set. + val pathSeven = s"${dir.getCanonicalPath}/table7" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathSeven) + val pathEight = s"${dir.getCanonicalPath}/table8" + (4 to 6).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathEight) + + val df2 = sqlContext.read.parquet(pathSeven, pathEight).filter("a = 1").selectExpr("a", "b") + checkAnswer( + df2, + Row(1, "1")) + + // The fields "a" and "b" exist in both two Parquet files. No metadata is set. + assert(!df2.schema("a").metadata.contains(StructType.metadataKeyForOptionalField)) + assert(!df2.schema("b").metadata.contains(StructType.metadataKeyForOptionalField)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 60fa81b1ab81..d860651d421f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.parquet.schema.MessageTypeParser +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -449,6 +450,35 @@ class ParquetSchemaSuite extends ParquetSchemaTest { }.getMessage.contains("detected conflicting schemas")) } + test("schema merging failure error message") { + withTempPath { dir => + val path = dir.getCanonicalPath + sqlContext.range(3).write.parquet(s"$path/p=1") + sqlContext.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2") + + val message = intercept[SparkException] { + sqlContext.read.option("mergeSchema", "true").parquet(path).schema + }.getMessage + + assert(message.contains("Failed merging schema of file")) + } + + // test for second merging (after read Parquet schema in parallel done) + withTempPath { dir => + val path = dir.getCanonicalPath + sqlContext.range(3).write.parquet(s"$path/p=1") + sqlContext.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2") + + sqlContext.sparkContext.conf.set("spark.default.parallelism", "20") + + val message = intercept[SparkException] { + sqlContext.read.option("mergeSchema", "true").parquet(path).schema + }.getMessage + + assert(message.contains("Failed merging schema:")) + } + } + // ======================================================= // Tests for converting Parquet LIST to Catalyst ArrayType // ======================================================= From 0e6d92d042b0a2920d8df5959d5913ba0166a678 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 30 Jan 2016 23:05:29 -0800 Subject: [PATCH 075/131] [SPARK-12689][SQL] Migrate DDL parsing to the newly absorbed parser JIRA: 
https://issues.apache.org/jira/browse/SPARK-12689 DDLParser processes three commands: createTable, describeTable and refreshTable. This patch migrates the three commands to newly absorbed parser. Author: Liang-Chi Hsieh Author: Liang-Chi Hsieh Closes #10723 from viirya/migrate-ddl-describe. --- project/MimaExcludes.scala | 5 + .../sql/catalyst/parser/ExpressionParser.g | 14 ++ .../spark/sql/catalyst/parser/SparkSqlLexer.g | 4 +- .../sql/catalyst/parser/SparkSqlParser.g | 80 +++++++- .../spark/sql/catalyst/CatalystQl.scala | 23 ++- .../org/apache/spark/sql/SQLContext.scala | 5 +- .../apache/spark/sql/execution/SparkQl.scala | 101 ++++++++- .../sql/execution/datasources/DDLParser.scala | 193 ------------------ .../spark/sql/execution/datasources/ddl.scala | 5 - .../sources/CreateTableAsSelectSuite.scala | 7 +- 10 files changed, 208 insertions(+), 229 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 3748e07f88aa..8b1a7303fc5b 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -200,6 +200,11 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$_log_="), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.TransactionProcessor.org$apache$spark$streaming$flume$sink$Logging$$log_"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.TransactionProcessor.org$apache$spark$streaming$flume$sink$Logging$$log__=") + ) ++ Seq( + // SPARK-12689 Migrate DDL parsing to the newly absorbed parser + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.execution.datasources.DDLParser"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.execution.datasources.DDLException"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.ddlParser") ) ++ Seq( // SPARK-7799 Add "streaming-akka" project ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream"), diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g index 0555a6ba83cb..c162c1a0c578 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g @@ -493,6 +493,16 @@ descFuncNames | functionIdentifier ; +//We are allowed to use From and To in CreateTableUsing command's options (actually seems we can use any string as the option key). But we can't simply add them into nonReserved because by doing that we mess other existing rules. So we create a looseIdentifier and looseNonReserved here. +looseIdentifier + : + Identifier + | looseNonReserved -> Identifier[$looseNonReserved.text] + // If it decides to support SQL11 reserved keywords, i.e., useSQL11ReservedKeywordsForIdentifier()=false, + // the sql11keywords in existing q tests will NOT be added back. + | {useSQL11ReservedKeywordsForIdentifier()}? 
sql11ReservedKeywordsUsedAsIdentifier -> Identifier[$sql11ReservedKeywordsUsedAsIdentifier.text] + ; + identifier : Identifier @@ -516,6 +526,10 @@ principalIdentifier | QuotedIdentifier ; +looseNonReserved + : nonReserved | KW_FROM | KW_TO + ; + //The new version of nonReserved + sql11ReservedKeywordsUsedAsIdentifier = old version of nonReserved //Non reserved keywords are basically the keywords that can be used as identifiers. //All the KW_* are automatically not only keywords, but also reserved keywords. diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g index 4374cd7ef720..e930caa291d4 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g @@ -324,6 +324,8 @@ KW_ISOLATION: 'ISOLATION'; KW_LEVEL: 'LEVEL'; KW_SNAPSHOT: 'SNAPSHOT'; KW_AUTOCOMMIT: 'AUTOCOMMIT'; +KW_REFRESH: 'REFRESH'; +KW_OPTIONS: 'OPTIONS'; KW_WEEK: 'WEEK'|'WEEKS'; KW_MILLISECOND: 'MILLISECOND'|'MILLISECONDS'; KW_MICROSECOND: 'MICROSECOND'|'MICROSECONDS'; @@ -470,7 +472,7 @@ Identifier fragment QuotedIdentifier : - '`' ( '``' | ~('`') )* '`' { setText(getText().substring(1, getText().length() -1 ).replaceAll("``", "`")); } + '`' ( '``' | ~('`') )* '`' { setText(getText().replaceAll("``", "`")); } ; WS : (' '|'\r'|'\t'|'\n') {$channel=HIDDEN;} diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g index 35bef00351d7..6591f6b0f56c 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g @@ -142,6 +142,7 @@ TOK_UNIONTYPE; TOK_COLTYPELIST; TOK_CREATEDATABASE; TOK_CREATETABLE; +TOK_CREATETABLEUSING; TOK_TRUNCATETABLE; TOK_CREATEINDEX; TOK_CREATEINDEX_INDEXTBLNAME; @@ -371,6 +372,10 @@ TOK_TXN_READ_WRITE; TOK_COMMIT; TOK_ROLLBACK; TOK_SET_AUTOCOMMIT; +TOK_REFRESHTABLE; +TOK_TABLEPROVIDER; +TOK_TABLEOPTIONS; +TOK_TABLEOPTION; TOK_CACHETABLE; TOK_UNCACHETABLE; TOK_CLEARCACHE; @@ -660,6 +665,12 @@ import java.util.HashMap; } private char [] excludedCharForColumnName = {'.', ':'}; private boolean containExcludedCharForCreateTableColumnName(String input) { + if (input.length() > 0) { + if (input.charAt(0) == '`' && input.charAt(input.length() - 1) == '`') { + // When column name is backquoted, we don't care about excluded chars. + return false; + } + } for(char c : excludedCharForColumnName) { if(input.indexOf(c)>-1) { return true; @@ -781,6 +792,7 @@ ddlStatement | truncateTableStatement | alterStatement | descStatement + | refreshStatement | showStatement | metastoreCheck | createViewStatement @@ -907,12 +919,31 @@ createTableStatement @init { pushMsg("create table statement", state); } @after { popMsg(state); } : KW_CREATE (temp=KW_TEMPORARY)? (ext=KW_EXTERNAL)? KW_TABLE ifNotExists? name=tableName - ( like=KW_LIKE likeName=tableName + ( + like=KW_LIKE likeName=tableName tableRowFormat? tableFileFormat? tableLocation? tablePropertiesPrefixed? + -> ^(TOK_CREATETABLE $name $temp? $ext? ifNotExists? + ^(TOK_LIKETABLE $likeName?) + tableRowFormat? + tableFileFormat? + tableLocation? + tablePropertiesPrefixed? + ) + | + tableProvider + tableOpts? + (KW_AS selectStatementWithCTE)? + -> ^(TOK_CREATETABLEUSING $name $temp? ifNotExists? 
+ tableProvider + tableOpts? + selectStatementWithCTE? + ) | (LPAREN columnNameTypeList RPAREN)? + (p=tableProvider?) + tableOpts? tableComment? tablePartition? tableBuckets? @@ -922,8 +953,15 @@ createTableStatement tableLocation? tablePropertiesPrefixed? (KW_AS selectStatementWithCTE)? - ) - -> ^(TOK_CREATETABLE $name $temp? $ext? ifNotExists? + -> {p != null}? + ^(TOK_CREATETABLEUSING $name $temp? ifNotExists? + columnNameTypeList? + $p + tableOpts? + selectStatementWithCTE? + ) + -> + ^(TOK_CREATETABLE $name $temp? $ext? ifNotExists? ^(TOK_LIKETABLE $likeName?) columnNameTypeList? tableComment? @@ -935,7 +973,8 @@ createTableStatement tableLocation? tablePropertiesPrefixed? selectStatementWithCTE? - ) + ) + ) ; truncateTableStatement @@ -1379,6 +1418,13 @@ tabPartColTypeExpr : tableName partitionSpec? extColumnName? -> ^(TOK_TABTYPE tableName partitionSpec? extColumnName?) ; +refreshStatement +@init { pushMsg("refresh statement", state); } +@after { popMsg(state); } + : + KW_REFRESH KW_TABLE tableName -> ^(TOK_REFRESHTABLE tableName) + ; + descStatement @init { pushMsg("describe statement", state); } @after { popMsg(state); } @@ -1774,6 +1820,30 @@ showStmtIdentifier | StringLiteral ; +tableProvider +@init { pushMsg("table's provider", state); } +@after { popMsg(state); } + : + KW_USING Identifier (DOT Identifier)* + -> ^(TOK_TABLEPROVIDER Identifier+) + ; + +optionKeyValue +@init { pushMsg("table's option specification", state); } +@after { popMsg(state); } + : + (looseIdentifier (DOT looseIdentifier)*) StringLiteral + -> ^(TOK_TABLEOPTION looseIdentifier+ StringLiteral) + ; + +tableOpts +@init { pushMsg("table's options", state); } +@after { popMsg(state); } + : + KW_OPTIONS LPAREN optionKeyValue (COMMA optionKeyValue)* RPAREN + -> ^(TOK_TABLEOPTIONS optionKeyValue+) + ; + tableComment @init { pushMsg("table's comment", state); } @after { popMsg(state); } @@ -2132,7 +2202,7 @@ structType mapType @init { pushMsg("map type", state); } @after { popMsg(state); } - : KW_MAP LESSTHAN left=primitiveType COMMA right=type GREATERTHAN + : KW_MAP LESSTHAN left=type COMMA right=type GREATERTHAN -> ^(TOK_MAP $left $right) ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index 536c292ab7f3..7ce2407913ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -140,6 +140,7 @@ private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) extends case Token("TOK_BOOLEAN", Nil) => BooleanType case Token("TOK_STRING", Nil) => StringType case Token("TOK_VARCHAR", Token(_, Nil) :: Nil) => StringType + case Token("TOK_CHAR", Token(_, Nil) :: Nil) => StringType case Token("TOK_FLOAT", Nil) => FloatType case Token("TOK_DOUBLE", Nil) => DoubleType case Token("TOK_DATE", Nil) => DateType @@ -156,9 +157,10 @@ private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) extends protected def nodeToStructField(node: ASTNode): StructField = node match { case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: Nil) => - StructField(fieldName, nodeToDataType(dataType), nullable = true) - case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: _ /* comment */:: Nil) => - StructField(fieldName, nodeToDataType(dataType), nullable = true) + StructField(cleanIdentifier(fieldName), nodeToDataType(dataType), nullable = true) + case Token("TOK_TABCOL", 
Token(fieldName, Nil) :: dataType :: comment :: Nil) => + val meta = new MetadataBuilder().putString("comment", unquoteString(comment.text)).build() + StructField(cleanIdentifier(fieldName), nodeToDataType(dataType), nullable = true, meta) case _ => noParseRule("StructField", node) } @@ -222,15 +224,16 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Nil => ShowFunctions(None, None) case Token(name, Nil) :: Nil => - ShowFunctions(None, Some(unquoteString(name))) + ShowFunctions(None, Some(unquoteString(cleanIdentifier(name)))) case Token(db, Nil) :: Token(name, Nil) :: Nil => - ShowFunctions(Some(unquoteString(db)), Some(unquoteString(name))) + ShowFunctions(Some(unquoteString(cleanIdentifier(db))), + Some(unquoteString(cleanIdentifier(name)))) case _ => noParseRule("SHOW FUNCTIONS", node) } case Token("TOK_DESCFUNCTION", Token(functionName, Nil) :: isExtended) => - DescribeFunction(functionName, isExtended.nonEmpty) + DescribeFunction(cleanIdentifier(functionName), isExtended.nonEmpty) case Token("TOK_QUERY", queryArgs @ Token("TOK_CTE" | "TOK_FROM" | "TOK_INSERT", _) :: _) => val (fromClause: Option[ASTNode], insertClauses, cteRelations) = @@ -611,7 +614,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C noParseRule("Select", node) } - protected val escapedIdentifier = "`([^`]+)`".r + protected val escapedIdentifier = "`(.+)`".r protected val doubleQuotedString = "\"([^\"]+)\"".r protected val singleQuotedString = "'([^']+)'".r @@ -655,7 +658,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C nodeToExpr(qualifier) match { case UnresolvedAttribute(nameParts) => UnresolvedAttribute(nameParts :+ cleanIdentifier(attr)) - case other => UnresolvedExtractValue(other, Literal(attr)) + case other => UnresolvedExtractValue(other, Literal(cleanIdentifier(attr))) } /* Stars (*) */ @@ -663,7 +666,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only // has a single child which is tableName. 
case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", target) :: Nil) if target.nonEmpty => - UnresolvedStar(Some(target.map(_.text))) + UnresolvedStar(Some(target.map(x => cleanIdentifier(x.text)))) /* Aggregate Functions */ case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => @@ -971,7 +974,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C protected def nodeToGenerate(node: ASTNode, outer: Boolean, child: LogicalPlan): Generate = { val Token("TOK_SELECT", Token("TOK_SELEXPR", clauses) :: Nil) = node - val alias = getClause("TOK_TABALIAS", clauses).children.head.text + val alias = cleanIdentifier(getClause("TOK_TABALIAS", clauses).children.head.text) val generator = clauses.head match { case Token("TOK_FUNCTION", Token(explode(), Nil) :: childNode :: Nil) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index be28df3a5155..ef993c3edae3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -206,10 +206,7 @@ class SQLContext private[sql]( @transient protected[sql] val sqlParser: ParserInterface = new SparkQl(conf) - @transient - protected[sql] val ddlParser: DDLParser = new DDLParser(sqlParser) - - protected[sql] def parseSql(sql: String): LogicalPlan = ddlParser.parse(sql, false) + protected[sql] def parseSql(sql: String): LogicalPlan = sqlParser.parsePlan(sql) protected[sql] def executeSql(sql: String): org.apache.spark.sql.execution.QueryExecution = executePlan(parseSql(sql)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala index a5bd8ee42dec..4174e27e9c8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala @@ -16,11 +16,14 @@ */ package org.apache.spark.sql.execution +import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalyst.{CatalystQl, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.parser.{ASTNode, ParserConf, SimpleParserConf} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.types.StructType private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends CatalystQl(conf) { /** Check if a command should not be explained. 
*/ @@ -55,6 +58,86 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs) ExplainCommand(nodeToPlan(query), extended = extended.isDefined) + case Token("TOK_REFRESHTABLE", nameParts :: Nil) => + val tableIdent = extractTableIdent(nameParts) + RefreshTable(tableIdent) + + case Token("TOK_CREATETABLEUSING", createTableArgs) => + val Seq( + temp, + allowExisting, + Some(tabName), + tableCols, + Some(Token("TOK_TABLEPROVIDER", providerNameParts)), + tableOpts, + tableAs) = getClauses(Seq( + "TEMPORARY", + "TOK_IFNOTEXISTS", + "TOK_TABNAME", "TOK_TABCOLLIST", + "TOK_TABLEPROVIDER", + "TOK_TABLEOPTIONS", + "TOK_QUERY"), createTableArgs) + + val tableIdent: TableIdentifier = extractTableIdent(tabName) + + val columns = tableCols.map { + case Token("TOK_TABCOLLIST", fields) => StructType(fields.map(nodeToStructField)) + } + + val provider = providerNameParts.map { + case Token(name, Nil) => name + }.mkString(".") + + val options: Map[String, String] = tableOpts.toSeq.flatMap { + case Token("TOK_TABLEOPTIONS", options) => + options.map { + case Token("TOK_TABLEOPTION", keysAndValue) => + val key = keysAndValue.init.map(_.text).mkString(".") + val value = unquoteString(keysAndValue.last.text) + (key, value) + } + }.toMap + + val asClause = tableAs.map(nodeToPlan(_)) + + if (temp.isDefined && allowExisting.isDefined) { + throw new AnalysisException( + "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.") + } + + if (asClause.isDefined) { + if (columns.isDefined) { + throw new AnalysisException( + "a CREATE TABLE AS SELECT statement does not allow column definitions.") + } + + val mode = if (allowExisting.isDefined) { + SaveMode.Ignore + } else if (temp.isDefined) { + SaveMode.Overwrite + } else { + SaveMode.ErrorIfExists + } + + CreateTableUsingAsSelect(tableIdent, + provider, + temp.isDefined, + Array.empty[String], + bucketSpec = None, + mode, + options, + asClause.get) + } else { + CreateTableUsing( + tableIdent, + columns, + provider, + temp.isDefined, + options, + allowExisting.isDefined, + managedIfNoPath = false) + } + case Token("TOK_SWITCHDATABASE", Token(database, Nil) :: Nil) => SetDatabaseCommand(cleanIdentifier(database)) @@ -68,26 +151,30 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly nodeToDescribeFallback(node) } else { tableType match { - case Token("TOK_TABTYPE", Token("TOK_TABNAME", nameParts :: Nil) :: Nil) => + case Token("TOK_TABTYPE", Token("TOK_TABNAME", nameParts) :: Nil) => nameParts match { - case Token(".", dbName :: tableName :: Nil) => + case Token(dbName, Nil) :: Token(tableName, Nil) :: Nil => // It is describing a table with the format like "describe db.table". // TODO: Actually, a user may mean tableName.columnName. Need to resolve this // issue. - val tableIdent = extractTableIdent(nameParts) + val tableIdent = TableIdentifier( + cleanIdentifier(tableName), Some(cleanIdentifier(dbName))) datasources.DescribeCommand( UnresolvedRelation(tableIdent, None), isExtended = extended.isDefined) - case Token(".", dbName :: tableName :: colName :: Nil) => + case Token(dbName, Nil) :: Token(tableName, Nil) :: Token(colName, Nil) :: Nil => // It is describing a column with the format like "describe db.table column". nodeToDescribeFallback(node) - case tableName => + case tableName :: Nil => // It is describing a table with the format like "describe table". 
datasources.DescribeCommand( - UnresolvedRelation(TableIdentifier(tableName.text), None), + UnresolvedRelation(TableIdentifier(cleanIdentifier(tableName.text)), None), isExtended = extended.isDefined) + case _ => + nodeToDescribeFallback(node) } // All other cases. - case _ => nodeToDescribeFallback(node) + case _ => + nodeToDescribeFallback(node) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala deleted file mode 100644 index f4766b037027..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala +++ /dev/null @@ -1,193 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.datasources - -import scala.language.implicitConversions -import scala.util.matching.Regex - -import org.apache.spark.Logging -import org.apache.spark.sql.SaveMode -import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, ParserInterface, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.util.DataTypeParser -import org.apache.spark.sql.types._ - -/** - * A parser for foreign DDL commands. 
- */ -class DDLParser(fallback: => ParserInterface) - extends AbstractSparkSQLParser with DataTypeParser with Logging { - - override def parseExpression(sql: String): Expression = fallback.parseExpression(sql) - - override def parseTableIdentifier(sql: String): TableIdentifier = { - fallback.parseTableIdentifier(sql) - } - - def parse(input: String, exceptionOnError: Boolean): LogicalPlan = { - try { - parsePlan(input) - } catch { - case ddlException: DDLException => throw ddlException - case _ if !exceptionOnError => fallback.parsePlan(input) - case x: Throwable => throw x - } - } - - // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` - // properties via reflection the class in runtime for constructing the SqlLexical object - protected val CREATE = Keyword("CREATE") - protected val TEMPORARY = Keyword("TEMPORARY") - protected val TABLE = Keyword("TABLE") - protected val IF = Keyword("IF") - protected val NOT = Keyword("NOT") - protected val EXISTS = Keyword("EXISTS") - protected val USING = Keyword("USING") - protected val OPTIONS = Keyword("OPTIONS") - protected val DESCRIBE = Keyword("DESCRIBE") - protected val EXTENDED = Keyword("EXTENDED") - protected val AS = Keyword("AS") - protected val COMMENT = Keyword("COMMENT") - protected val REFRESH = Keyword("REFRESH") - - protected lazy val ddl: Parser[LogicalPlan] = createTable | describeTable | refreshTable - - protected def start: Parser[LogicalPlan] = ddl - - /** - * `CREATE [TEMPORARY] TABLE [IF NOT EXISTS] avroTable - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * or - * `CREATE [TEMPORARY] TABLE [IF NOT EXISTS] avroTable(intField int, stringField string...) - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * or - * `CREATE [TEMPORARY] TABLE [IF NOT EXISTS] avroTable - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * AS SELECT ... - */ - protected lazy val createTable: Parser[LogicalPlan] = { - // TODO: Support database.table. - (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ tableIdentifier ~ - tableCols.? ~ (USING ~> className) ~ (OPTIONS ~> options).? ~ (AS ~> restInput).? ^^ { - case temp ~ allowExisting ~ tableIdent ~ columns ~ provider ~ opts ~ query => - if (temp.isDefined && allowExisting.isDefined) { - throw new DDLException( - "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.") - } - - val options = opts.getOrElse(Map.empty[String, String]) - if (query.isDefined) { - if (columns.isDefined) { - throw new DDLException( - "a CREATE TABLE AS SELECT statement does not allow column definitions.") - } - // When IF NOT EXISTS clause appears in the query, the save mode will be ignore. - val mode = if (allowExisting.isDefined) { - SaveMode.Ignore - } else if (temp.isDefined) { - SaveMode.Overwrite - } else { - SaveMode.ErrorIfExists - } - - val queryPlan = fallback.parsePlan(query.get) - CreateTableUsingAsSelect(tableIdent, - provider, - temp.isDefined, - Array.empty[String], - bucketSpec = None, - mode, - options, - queryPlan) - } else { - val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields))) - CreateTableUsing( - tableIdent, - userSpecifiedSchema, - provider, - temp.isDefined, - options, - allowExisting.isDefined, - managedIfNoPath = false) - } - } - } - - // This is the same as tableIdentifier in SqlParser. 
- protected lazy val tableIdentifier: Parser[TableIdentifier] = - (ident <~ ".").? ~ ident ^^ { - case maybeDbName ~ tableName => TableIdentifier(tableName, maybeDbName) - } - - protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" - - /* - * describe [extended] table avroTable - * This will display all columns of table `avroTable` includes column_name,column_type,comment - */ - protected lazy val describeTable: Parser[LogicalPlan] = - (DESCRIBE ~> opt(EXTENDED)) ~ tableIdentifier ^^ { - case e ~ tableIdent => - DescribeCommand(UnresolvedRelation(tableIdent, None), e.isDefined) - } - - protected lazy val refreshTable: Parser[LogicalPlan] = - REFRESH ~> TABLE ~> tableIdentifier ^^ { - case tableIndet => - RefreshTable(tableIndet) - } - - protected lazy val options: Parser[Map[String, String]] = - "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } - - protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} - - override implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch( - s"identifier matching regex $regex", { - case lexical.Identifier(str) if regex.unapplySeq(str).isDefined => str - case lexical.Keyword(str) if regex.unapplySeq(str).isDefined => str - } - ) - - protected lazy val optionPart: Parser[String] = "[_a-zA-Z][_a-zA-Z0-9]*".r ^^ { - case name => name - } - - protected lazy val optionName: Parser[String] = repsep(optionPart, ".") ^^ { - case parts => parts.mkString(".") - } - - protected lazy val pair: Parser[(String, String)] = - optionName ~ stringLit ^^ { case k ~ v => (k, v) } - - protected lazy val column: Parser[StructField] = - ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm => - val meta = cm match { - case Some(comment) => - new MetadataBuilder().putString(COMMENT.str.toLowerCase, comment).build() - case None => Metadata.empty - } - - StructField(columnName, typ, nullable = true, meta) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index c3603936dfd2..1554209be989 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -169,8 +169,3 @@ class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] override def -(key: String): Map[String, String] = baseMap - key.toLowerCase } - -/** - * The exception thrown from the DDL parser. 
- */ -class DDLException(message: String) extends RuntimeException(message) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 6fc9febe4970..cb88a1c83c99 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -22,7 +22,6 @@ import java.io.{File, IOException} import org.scalatest.BeforeAndAfter import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.execution.datasources.DDLException import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils @@ -105,7 +104,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with sql("SELECT a, b FROM jsonTable"), sql("SELECT a, b FROM jt").collect()) - val message = intercept[DDLException]{ + val message = intercept[AnalysisException]{ sql( s""" |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable @@ -156,7 +155,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with } test("CREATE TEMPORARY TABLE AS SELECT with IF NOT EXISTS is not allowed") { - val message = intercept[DDLException]{ + val message = intercept[AnalysisException]{ sql( s""" |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable @@ -173,7 +172,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with } test("a CTAS statement with column definitions is not allowed") { - intercept[DDLException]{ + intercept[AnalysisException]{ sql( s""" |CREATE TEMPORARY TABLE jsonTable (a int, b string) From 5a8b978fabb60aa178274f86432c63680c8b351a Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sun, 31 Jan 2016 13:56:13 -0800 Subject: [PATCH 076/131] [SPARK-13049] Add First/last with ignore nulls to functions.scala This PR adds the ability to specify the ```ignoreNulls``` option to the functions dsl, e.g: ```df.select($"id", last($"value", ignoreNulls = true).over(Window.partitionBy($"id").orderBy($"other"))``` This PR is some where between a bug fix (see the JIRA) and a new feature. I am not sure if we should backport to 1.6. cc yhuai Author: Herman van Hovell Closes #10957 from hvanhovell/SPARK-13049. --- python/pyspark/sql/functions.py | 26 +++- python/pyspark/sql/tests.py | 10 ++ .../org/apache/spark/sql/functions.scala | 118 ++++++++++++++---- .../spark/sql/DataFrameWindowSuite.scala | 32 +++++ 4 files changed, 157 insertions(+), 29 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 719eca8f5559..0d5708526701 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -81,8 +81,6 @@ def _(): 'max': 'Aggregate function: returns the maximum value of the expression in a group.', 'min': 'Aggregate function: returns the minimum value of the expression in a group.', - 'first': 'Aggregate function: returns the first value in a group.', - 'last': 'Aggregate function: returns the last value in a group.', 'count': 'Aggregate function: returns the number of items in a group.', 'sum': 'Aggregate function: returns the sum of all values in the expression.', 'avg': 'Aggregate function: returns the average of the values in a group.', @@ -278,6 +276,18 @@ def countDistinct(col, *cols): return Column(jc) +@since(1.3) +def first(col, ignorenulls=False): + """Aggregate function: returns the first value in a group. 
+ + The function by default returns the first values it sees. It will return the first non-null + value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.first(_to_java_column(col), ignorenulls) + return Column(jc) + + @since(1.6) def input_file_name(): """Creates a string column for the file name of the current Spark task. @@ -310,6 +320,18 @@ def isnull(col): return Column(sc._jvm.functions.isnull(_to_java_column(col))) +@since(1.3) +def last(col, ignorenulls=False): + """Aggregate function: returns the last value in a group. + + The function by default returns the last values it sees. It will return the last non-null + value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.last(_to_java_column(col), ignorenulls) + return Column(jc) + + @since(1.6) def monotonically_increasing_id(): """A column that generates monotonically increasing 64-bit integers. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 410efbafe079..e30aa0a79692 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -641,6 +641,16 @@ def test_aggregator(self): self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0]) self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0]) + def test_first_last_ignorenulls(self): + from pyspark.sql import functions + df = self.sqlCtx.range(0, 100) + df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id")) + df3 = df2.select(functions.first(df2.id, False).alias('a'), + functions.first(df2.id, True).alias('b'), + functions.last(df2.id, False).alias('c'), + functions.last(df2.id, True).alias('d')) + self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect()) + def test_corr(self): import math df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3a27466176a2..b970eee4e31a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -349,19 +349,51 @@ object functions extends LegacyFunctions { } /** - * Aggregate function: returns the first value in a group. - * - * @group agg_funcs - * @since 1.3.0 - */ - def first(e: Column): Column = withAggregateFunction { new First(e.expr) } - - /** - * Aggregate function: returns the first value of a column in a group. - * - * @group agg_funcs - * @since 1.3.0 - */ + * Aggregate function: returns the first value in a group. + * + * The function by default returns the first values it sees. It will return the first non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 2.0.0 + */ + def first(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction { + new First(e.expr, Literal(ignoreNulls)) + } + + /** + * Aggregate function: returns the first value of a column in a group. + * + * The function by default returns the first values it sees. It will return the first non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. 
+ * + * @group agg_funcs + * @since 2.0.0 + */ + def first(columnName: String, ignoreNulls: Boolean): Column = { + first(Column(columnName), ignoreNulls) + } + + /** + * Aggregate function: returns the first value in a group. + * + * The function by default returns the first values it sees. It will return the first non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 1.3.0 + */ + def first(e: Column): Column = first(e, ignoreNulls = false) + + /** + * Aggregate function: returns the first value of a column in a group. + * + * The function by default returns the first values it sees. It will return the first non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 1.3.0 + */ def first(columnName: String): Column = first(Column(columnName)) /** @@ -381,20 +413,52 @@ object functions extends LegacyFunctions { def kurtosis(columnName: String): Column = kurtosis(Column(columnName)) /** - * Aggregate function: returns the last value in a group. - * - * @group agg_funcs - * @since 1.3.0 - */ - def last(e: Column): Column = withAggregateFunction { new Last(e.expr) } - - /** - * Aggregate function: returns the last value of the column in a group. - * - * @group agg_funcs - * @since 1.3.0 - */ - def last(columnName: String): Column = last(Column(columnName)) + * Aggregate function: returns the last value in a group. + * + * The function by default returns the last values it sees. It will return the last non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 2.0.0 + */ + def last(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction { + new Last(e.expr, Literal(ignoreNulls)) + } + + /** + * Aggregate function: returns the last value of the column in a group. + * + * The function by default returns the last values it sees. It will return the last non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 2.0.0 + */ + def last(columnName: String, ignoreNulls: Boolean): Column = { + last(Column(columnName), ignoreNulls) + } + + /** + * Aggregate function: returns the last value in a group. + * + * The function by default returns the last values it sees. It will return the last non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 1.3.0 + */ + def last(e: Column): Column = last(e, ignoreNulls = false) + + /** + * Aggregate function: returns the last value of the column in a group. + * + * The function by default returns the last values it sees. It will return the last non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 1.3.0 + */ + def last(columnName: String): Column = last(Column(columnName), ignoreNulls = false) /** * Aggregate function: returns the maximum value of the expression in a group. 
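Before the window-function test below, a short usage sketch of the new overloads. Everything here is illustrative: `df` is assumed to be a DataFrame with columns ("key", "order", "value") whose "value" column contains nulls, and `sqlContext.implicits._` is assumed to be in scope for the `$` syntax.

```
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{first, last}

val w = Window.partitionBy($"key").orderBy($"order")

df.select(
  $"key", $"order",
  first($"value", ignoreNulls = true).over(w), // first non-null value in the frame
  last($"value", ignoreNulls = true).over(w),  // most recent non-null value so far
  last($"value").over(w)                       // default: nulls are not skipped
)
```

With the default `ignoreNulls = false`, a trailing null row makes `last` return null for that row; with `ignoreNulls = true` the latest non-null value is carried forward, which is exactly what the new DataFrameWindowSuite test verifies.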
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala index 09a56f6f3ae2..d38842c3c0cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala @@ -312,4 +312,36 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { Row("b", 3, null, null), Row("b", 2, null, null))) } + + test("last/first with ignoreNulls") { + val nullStr: String = null + val df = Seq( + ("a", 0, nullStr), + ("a", 1, "x"), + ("a", 2, "y"), + ("a", 3, "z"), + ("a", 4, nullStr), + ("b", 1, nullStr), + ("b", 2, nullStr)). + toDF("key", "order", "value") + val window = Window.partitionBy($"key").orderBy($"order") + checkAnswer( + df.select( + $"key", + $"order", + first($"value").over(window), + first($"value", ignoreNulls = false).over(window), + first($"value", ignoreNulls = true).over(window), + last($"value").over(window), + last($"value", ignoreNulls = false).over(window), + last($"value", ignoreNulls = true).over(window)), + Seq( + Row("a", 0, null, null, null, null, null, null), + Row("a", 1, null, null, "x", "x", "x", "x"), + Row("a", 2, null, null, "x", "y", "y", "y"), + Row("a", 3, null, null, "x", "z", "z", "z"), + Row("a", 4, null, null, "x", null, null, "z"), + Row("b", 1, null, null, null, null, null, null), + Row("b", 2, null, null, null, null, null, null))) + } } From c1da4d421ab78772ffa52ad46e5bdfb4e5268f47 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 31 Jan 2016 22:43:03 -0800 Subject: [PATCH 077/131] [SPARK-13093] [SQL] improve null check in nullSafeCodeGen for unary, binary and ternary expression The current implementation is sub-optimal: * If an expression is always nullable, e.g. `Unhex`, we can still remove null check for children if they are not nullable. * If an expression has some non-nullable children, we can still remove null check for these children and keep null check for others. This PR improves this by making the null check elimination more fine-grained. Author: Wenchen Fan Closes #10987 from cloud-fan/null-check. --- .../sql/catalyst/expressions/Expression.scala | 104 ++++++++++-------- .../expressions/codegen/CodeGenerator.scala | 32 +++++- .../spark/sql/catalyst/expressions/misc.scala | 16 +-- 3 files changed, 85 insertions(+), 67 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index db17ba7c84ff..353fb92581d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -320,7 +320,7 @@ abstract class UnaryExpression extends Expression { /** * Called by unary expressions to generate a code block that returns null if its parent returns - * null, and if not not null, use `f` to generate the expression. + * null, and if not null, use `f` to generate the expression. * * As an example, the following does a boolean inversion (i.e. NOT). * {{{ @@ -340,7 +340,7 @@ abstract class UnaryExpression extends Expression { /** * Called by unary expressions to generate a code block that returns null if its parent returns - * null, and if not not null, use `f` to generate the expression. + * null, and if not null, use `f` to generate the expression. 
* * @param f function that accepts the non-null evaluation result name of child and returns Java * code to compute the output. @@ -349,20 +349,23 @@ abstract class UnaryExpression extends Expression { ctx: CodegenContext, ev: ExprCode, f: String => String): String = { - val eval = child.gen(ctx) + val childGen = child.gen(ctx) + val resultCode = f(childGen.value) + if (nullable) { - eval.code + s""" - boolean ${ev.isNull} = ${eval.isNull}; + val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode) + s""" + ${childGen.code} + boolean ${ev.isNull} = ${childGen.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${eval.isNull}) { - ${f(eval.value)} - } + $nullSafeEval """ } else { ev.isNull = "false" - eval.code + s""" + s""" + ${childGen.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - ${f(eval.value)} + $resultCode """ } } @@ -440,29 +443,31 @@ abstract class BinaryExpression extends Expression { ctx: CodegenContext, ev: ExprCode, f: (String, String) => String): String = { - val eval1 = left.gen(ctx) - val eval2 = right.gen(ctx) - val resultCode = f(eval1.value, eval2.value) + val leftGen = left.gen(ctx) + val rightGen = right.gen(ctx) + val resultCode = f(leftGen.value, rightGen.value) + if (nullable) { + val nullSafeEval = + leftGen.code + ctx.nullSafeExec(left.nullable, leftGen.isNull) { + rightGen.code + ctx.nullSafeExec(right.nullable, rightGen.isNull) { + s""" + ${ev.isNull} = false; // resultCode could change nullability. + $resultCode + """ + } + } + s""" - ${eval1.code} - boolean ${ev.isNull} = ${eval1.isNull}; + boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${eval2.code} - if (!${eval2.isNull}) { - $resultCode - } else { - ${ev.isNull} = true; - } - } + $nullSafeEval """ - } else { ev.isNull = "false" s""" - ${eval1.code} - ${eval2.code} + ${leftGen.code} + ${rightGen.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; $resultCode """ @@ -527,7 +532,7 @@ abstract class TernaryExpression extends Expression { /** * Default behavior of evaluation according to the default nullability of TernaryExpression. - * If subclass of BinaryExpression override nullable, probably should also override this. + * If subclass of TernaryExpression override nullable, probably should also override this. */ override def eval(input: InternalRow): Any = { val exprs = children @@ -553,11 +558,11 @@ abstract class TernaryExpression extends Expression { sys.error(s"BinaryExpressions must override either eval or nullSafeEval") /** - * Short hand for generating binary evaluation code. + * Short hand for generating ternary evaluation code. * If either of the sub-expressions is null, the result of this computation * is assumed to be null. * - * @param f accepts two variable names and returns Java code to compute the output. + * @param f accepts three variable names and returns Java code to compute the output. */ protected def defineCodeGen( ctx: CodegenContext, @@ -569,41 +574,46 @@ abstract class TernaryExpression extends Expression { } /** - * Short hand for generating binary evaluation code. + * Short hand for generating ternary evaluation code. * If either of the sub-expressions is null, the result of this computation * is assumed to be null. 
* - * @param f function that accepts the 2 non-null evaluation result names of children + * @param f function that accepts the 3 non-null evaluation result names of children * and returns Java code to compute the output. */ protected def nullSafeCodeGen( ctx: CodegenContext, ev: ExprCode, f: (String, String, String) => String): String = { - val evals = children.map(_.gen(ctx)) - val resultCode = f(evals(0).value, evals(1).value, evals(2).value) + val leftGen = children(0).gen(ctx) + val midGen = children(1).gen(ctx) + val rightGen = children(2).gen(ctx) + val resultCode = f(leftGen.value, midGen.value, rightGen.value) + if (nullable) { + val nullSafeEval = + leftGen.code + ctx.nullSafeExec(children(0).nullable, leftGen.isNull) { + midGen.code + ctx.nullSafeExec(children(1).nullable, midGen.isNull) { + rightGen.code + ctx.nullSafeExec(children(2).nullable, rightGen.isNull) { + s""" + ${ev.isNull} = false; // resultCode could change nullability. + $resultCode + """ + } + } + } + s""" - ${evals(0).code} boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${evals(0).isNull}) { - ${evals(1).code} - if (!${evals(1).isNull}) { - ${evals(2).code} - if (!${evals(2).isNull}) { - ${ev.isNull} = false; // resultCode could change nullability - $resultCode - } - } - } + $nullSafeEval """ } else { ev.isNull = "false" s""" - ${evals(0).code} - ${evals(1).code} - ${evals(2).code} + ${leftGen.code} + ${midGen.code} + ${rightGen.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; $resultCode """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 21f9198073d7..a30aba16170a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -402,17 +402,37 @@ class CodegenContext { } /** - * Generates code for greater of two expressions. - * - * @param dataType data type of the expressions - * @param c1 name of the variable of expression 1's output - * @param c2 name of the variable of expression 2's output - */ + * Generates code for greater of two expressions. + * + * @param dataType data type of the expressions + * @param c1 name of the variable of expression 1's output + * @param c2 name of the variable of expression 2's output + */ def genGreater(dataType: DataType, c1: String, c2: String): String = javaType(dataType) match { case JAVA_BYTE | JAVA_SHORT | JAVA_INT | JAVA_LONG => s"$c1 > $c2" case _ => s"(${genComp(dataType, c1, c2)}) > 0" } + /** + * Generates code to do null safe execution, i.e. only execute the code when the input is not + * null by adding null check if necessary. + * + * @param nullable used to decide whether we should add null check or not. + * @param isNull the code to check if the input is null. + * @param execute the code that should only be executed when the input is not null. + */ + def nullSafeExec(nullable: Boolean, isNull: String)(execute: String): String = { + if (nullable) { + s""" + if (!$isNull) { + $execute + } + """ + } else { + "\n" + execute + } + } + /** * List of java data types that have special accessors and setters in [[InternalRow]]. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 8480c3f9a12f..36e1fa1176d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -327,7 +327,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression ev.isNull = "false" val childrenHash = children.map { child => val childGen = child.gen(ctx) - childGen.code + generateNullCheck(child.nullable, childGen.isNull) { + childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) { computeHash(childGen.value, child.dataType, ev.value, ctx) } }.mkString("\n") @@ -338,18 +338,6 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression """ } - private def generateNullCheck(nullable: Boolean, isNull: String)(execution: String): String = { - if (nullable) { - s""" - if (!$isNull) { - $execution - } - """ - } else { - "\n" + execution - } - } - private def nullSafeElementHash( input: String, index: String, @@ -359,7 +347,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression ctx: CodegenContext): String = { val element = ctx.freshName("element") - generateNullCheck(nullable, s"$input.isNullAt($index)") { + ctx.nullSafeExec(nullable, s"$input.isNullAt($index)") { s""" final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)}; ${computeHash(element, elementType, result, ctx)} From 6075573a93176ee8c071888e4525043d9e73b061 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 1 Feb 2016 11:02:17 -0800 Subject: [PATCH 078/131] [SPARK-6847][CORE][STREAMING] Fix stack overflow issue when updateStateByKey is followed by a checkpointed dstream Add a local property to indicate if checkpointing all RDDs that are marked with the checkpoint flag, and enable it in Streaming Author: Shixiong Zhu Closes #10934 from zsxwing/recursive-checkpoint. --- .../main/scala/org/apache/spark/rdd/RDD.scala | 19 +++++ .../org/apache/spark/CheckpointSuite.scala | 21 ++++++ .../streaming/scheduler/JobGenerator.scala | 5 ++ .../streaming/scheduler/JobScheduler.scala | 7 +- .../spark/streaming/CheckpointSuite.scala | 69 +++++++++++++++++++ 5 files changed, 119 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index be47172581b7..e8157cf4ebe7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1542,6 +1542,15 @@ abstract class RDD[T: ClassTag]( private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None + // Whether to checkpoint all ancestor RDDs that are marked for checkpointing. By default, + // we stop as soon as we find the first such RDD, an optimization that allows us to write + // less data but is not safe for all workloads. E.g. in streaming we may checkpoint both + // an RDD and its parent in every batch, in which case the parent may never be checkpointed + // and its lineage never truncated, leading to OOMs in the long run (SPARK-6847). 
+ private val checkpointAllMarkedAncestors = + Option(sc.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS)) + .map(_.toBoolean).getOrElse(false) + /** Returns the first parent RDD */ protected[spark] def firstParent[U: ClassTag]: RDD[U] = { dependencies.head.rdd.asInstanceOf[RDD[U]] @@ -1585,6 +1594,13 @@ abstract class RDD[T: ClassTag]( if (!doCheckpointCalled) { doCheckpointCalled = true if (checkpointData.isDefined) { + if (checkpointAllMarkedAncestors) { + // TODO We can collect all the RDDs that needs to be checkpointed, and then checkpoint + // them in parallel. + // Checkpoint parents first because our lineage will be truncated after we + // checkpoint ourselves + dependencies.foreach(_.rdd.doCheckpoint()) + } checkpointData.get.checkpoint() } else { dependencies.foreach(_.rdd.doCheckpoint()) @@ -1704,6 +1720,9 @@ abstract class RDD[T: ClassTag]( */ object RDD { + private[spark] val CHECKPOINT_ALL_MARKED_ANCESTORS = + "spark.checkpoint.checkpointAllMarkedAncestors" + // The following implicit functions were in SparkContext before 1.3 and users had to // `import SparkContext._` to enable them. Now we move them here to make the compiler find // them automatically. However, we still keep the old functions in SparkContext for backward diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index 390764ba242f..ce35856dce3f 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -512,6 +512,27 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS assert(rdd.isCheckpointedAndMaterialized === true) assert(rdd.partitions.size === 0) } + + runTest("checkpointAllMarkedAncestors") { reliableCheckpoint: Boolean => + testCheckpointAllMarkedAncestors(reliableCheckpoint, checkpointAllMarkedAncestors = true) + testCheckpointAllMarkedAncestors(reliableCheckpoint, checkpointAllMarkedAncestors = false) + } + + private def testCheckpointAllMarkedAncestors( + reliableCheckpoint: Boolean, checkpointAllMarkedAncestors: Boolean): Unit = { + sc.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, checkpointAllMarkedAncestors.toString) + try { + val rdd1 = sc.parallelize(1 to 10) + checkpoint(rdd1, reliableCheckpoint) + val rdd2 = rdd1.map(_ + 1) + checkpoint(rdd2, reliableCheckpoint) + rdd2.count() + assert(rdd1.isCheckpointed === checkpointAllMarkedAncestors) + assert(rdd2.isCheckpointed === true) + } finally { + sc.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, null) + } + } } /** RDD partition that has large serialized size. 
*/ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index a5a01e77639c..a3ad5eaa40ed 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -20,6 +20,7 @@ package org.apache.spark.streaming.scheduler import scala.util.{Failure, Success, Try} import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time} import org.apache.spark.streaming.util.RecurringTimer import org.apache.spark.util.{Clock, EventLoop, ManualClock, Utils} @@ -243,6 +244,10 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { // Example: BlockRDDs are created in this thread, and it needs to access BlockManager // Update: This is probably redundant after threadlocal stuff in SparkEnv has been removed. SparkEnv.set(ssc.env) + + // Checkpoint all RDDs marked for checkpointing to ensure their lineages are + // truncated periodically. Otherwise, we may run into stack overflows (SPARK-6847). + ssc.sparkContext.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, "true") Try { jobScheduler.receiverTracker.allocateBlocksToBatch(time) // allocate received blocks to batch graph.generateJobs(time) // generate jobs using allocated block diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 9535c8e5b768..3fed3d88354c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -23,10 +23,10 @@ import scala.collection.JavaConverters._ import scala.util.Failure import org.apache.spark.Logging -import org.apache.spark.rdd.PairRDDFunctions +import org.apache.spark.rdd.{PairRDDFunctions, RDD} import org.apache.spark.streaming._ import org.apache.spark.streaming.ui.UIUtils -import org.apache.spark.util.{EventLoop, ThreadUtils, Utils} +import org.apache.spark.util.{EventLoop, ThreadUtils} private[scheduler] sealed trait JobSchedulerEvent @@ -210,6 +210,9 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { s"""Streaming job from $batchLinkText""") ssc.sc.setLocalProperty(BATCH_TIME_PROPERTY_KEY, job.time.milliseconds.toString) ssc.sc.setLocalProperty(OUTPUT_OP_ID_PROPERTY_KEY, job.outputOpId.toString) + // Checkpoint all RDDs marked for checkpointing to ensure their lineages are + // truncated periodically. Otherwise, we may run into stack overflows (SPARK-6847). + ssc.sparkContext.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, "true") // We need to assign `eventLoop` to a temp variable. 
Otherwise, because // `JobScheduler.stop(false)` may set `eventLoop` to null when this method is running, then diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 4a6b91fbc745..786703eb9a84 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -821,6 +821,75 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester checkpointWriter.stop() } + test("SPARK-6847: stack overflow when updateStateByKey is followed by a checkpointed dstream") { + // In this test, there are two updateStateByKey operators. The RDD DAG is as follows: + // + // batch 1 batch 2 batch 3 ... + // + // 1) input rdd input rdd input rdd + // | | | + // v v v + // 2) cogroup rdd ---> cogroup rdd ---> cogroup rdd ... + // | / | / | + // v / v / v + // 3) map rdd --- map rdd --- map rdd ... + // | | | + // v v v + // 4) cogroup rdd ---> cogroup rdd ---> cogroup rdd ... + // | / | / | + // v / v / v + // 5) map rdd --- map rdd --- map rdd ... + // + // Every batch depends on its previous batch, so "updateStateByKey" needs to do checkpoint to + // break the RDD chain. However, before SPARK-6847, when the state RDD (layer 5) of the second + // "updateStateByKey" does checkpoint, it won't checkpoint the state RDD (layer 3) of the first + // "updateStateByKey" (Note: "updateStateByKey" has already marked that its state RDD (layer 3) + // should be checkpointed). Hence, the connections between layer 2 and layer 3 won't be broken + // and the RDD chain will grow infinitely and cause StackOverflow. + // + // Therefore SPARK-6847 introduces "spark.checkpoint.checkpointAllMarked" to force checkpointing + // all marked RDDs in the DAG to resolve this issue. (For the previous example, it will break + // connections between layer 2 and layer 3) + ssc = new StreamingContext(master, framework, batchDuration) + val batchCounter = new BatchCounter(ssc) + ssc.checkpoint(checkpointDir) + val inputDStream = new CheckpointInputDStream(ssc) + val updateFunc = (values: Seq[Int], state: Option[Int]) => { + Some(values.sum + state.getOrElse(0)) + } + @volatile var shouldCheckpointAllMarkedRDDs = false + @volatile var rddsCheckpointed = false + inputDStream.map(i => (i, i)) + .updateStateByKey(updateFunc).checkpoint(batchDuration) + .updateStateByKey(updateFunc).checkpoint(batchDuration) + .foreachRDD { rdd => + /** + * Find all RDDs that are marked for checkpointing in the specified RDD and its ancestors. + */ + def findAllMarkedRDDs(rdd: RDD[_]): List[RDD[_]] = { + val markedRDDs = rdd.dependencies.flatMap(dep => findAllMarkedRDDs(dep.rdd)).toList + if (rdd.checkpointData.isDefined) { + rdd :: markedRDDs + } else { + markedRDDs + } + } + + shouldCheckpointAllMarkedRDDs = + Option(rdd.sparkContext.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS)). + map(_.toBoolean).getOrElse(false) + + val stateRDDs = findAllMarkedRDDs(rdd) + rdd.count() + // Check the two state RDDs are both checkpointed + rddsCheckpointed = stateRDDs.size == 2 && stateRDDs.forall(_.isCheckpointed) + } + ssc.start() + batchCounter.waitUntilBatchesCompleted(1, 10000) + assert(shouldCheckpointAllMarkedRDDs === true) + assert(rddsCheckpointed === true) + } + /** * Advances the manual clock on the streaming scheduler by given number of batches. * It also waits for the expected amount of time for each batch. 
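Editor's note: a minimal sketch of the SPARK-6847 behavior outside of streaming, not part of the patch. It mirrors the new `CheckpointSuite` test above and assumes an existing `SparkContext` named `sc` with `sc.setCheckpointDir(...)` already called.

```scala
// Illustrative sketch only (not part of the patch); mirrors the CheckpointSuite test above.
// Assumes an existing SparkContext `sc` whose checkpoint directory is already configured.
sc.setLocalProperty("spark.checkpoint.checkpointAllMarkedAncestors", "true")

val parent = sc.parallelize(1 to 10)
parent.checkpoint()                // marked for checkpointing
val child = parent.map(_ + 1)
child.checkpoint()                 // also marked

child.count()                      // the action triggers doCheckpoint()

// With the property enabled, the marked ancestor is checkpointed as well, so its lineage
// is truncated; with the default ("false") only `child` would be checkpointed.
assert(parent.isCheckpointed && child.isCheckpointed)
```

Streaming enables this local property automatically in `JobGenerator` and `JobScheduler`, as shown in the diff above.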
From 33c8a490f7f64320c53530a57bd8d34916e3607c Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 1 Feb 2016 11:22:02 -0800 Subject: [PATCH 079/131] [SPARK-12989][SQL] Delaying Alias Cleanup after ExtractWindowExpressions JIRA: https://issues.apache.org/jira/browse/SPARK-12989 In the rule `ExtractWindowExpressions`, we simply replace alias by the corresponding attribute. However, this will cause an issue exposed by the following case: ```scala val data = Seq(("a", "b", "c", 3), ("c", "b", "a", 3)).toDF("A", "B", "C", "num") .withColumn("Data", struct("A", "B", "C")) .drop("A") .drop("B") .drop("C") val winSpec = Window.partitionBy("Data.A", "Data.B").orderBy($"num".desc) data.select($"*", max("num").over(winSpec) as "max").explain(true) ``` In this case, both `Data.A` and `Data.B` are `alias` in `WindowSpecDefinition`. If we replace these alias expression by their alias names, we are unable to know what they are since they will not be put in `missingExpr` too. Author: gatorsmile Author: xiaoli Author: Xiao Li Closes #10963 from gatorsmile/seletStarAfterColDrop. --- .../apache/spark/sql/catalyst/analysis/Analyzer.scala | 5 +++-- .../org/apache/spark/sql/DataFrameWindowSuite.scala | 10 ++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5fe700ee0067..ee60fca1ad4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -883,12 +883,13 @@ class Analyzer( if (missingExpr.nonEmpty) { extractedExprBuffer += ne } - ne.toAttribute + // alias will be cleaned in the rule CleanupAliases + ne case e: Expression if e.foldable => e // No need to create an attribute reference if it will be evaluated as a Literal. case e: Expression => // For other expressions, we extract it and replace it with an AttributeReference (with - // an interal column name, e.g. "_w0"). + // an internal column name, e.g. "_w0"). 
val withName = Alias(e, s"_w${extractedExprBuffer.length}")() extractedExprBuffer += withName withName.toAttribute diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala index d38842c3c0cf..2bcbb1983f7a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala @@ -344,4 +344,14 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { Row("b", 1, null, null, null, null, null, null), Row("b", 2, null, null, null, null, null, null))) } + + test("SPARK-12989 ExtractWindowExpressions treats alias as regular attribute") { + val src = Seq((0, 3, 5)).toDF("a", "b", "c") + .withColumn("Data", struct("a", "b")) + .drop("a") + .drop("b") + val winSpec = Window.partitionBy("Data.a", "Data.b").orderBy($"c".desc) + val df = src.select($"*", max("c").over(winSpec) as "max") + checkAnswer(df, Row(5, Row(0, 3), 5)) + } } From 8f26eb5ef6853a6666d7d9481b333de70bc501ed Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 1 Feb 2016 11:57:13 -0800 Subject: [PATCH 080/131] [SPARK-12705][SPARK-10777][SQL] Analyzer Rule ResolveSortReferences JIRA: https://issues.apache.org/jira/browse/SPARK-12705 **Scope:** This PR is a general fix for sorting reference resolution when the child's `outputSet` does not have the order-by attributes (called, *missing attributes*): - UnaryNode support is limited to `Project`, `Window`, `Aggregate`, `Distinct`, `Filter`, `RepartitionByExpression`. - We will not try to resolve the missing references inside a subquery, unless the outputSet of this subquery contains it. **General Reference Resolution Rules:** - Jump over the nodes with the following types: `Distinct`, `Filter`, `RepartitionByExpression`. Do not need to add missing attributes. The reason is their `outputSet` is decided by their `inputSet`, which is the `outputSet` of their children. - Group-by expressions in `Aggregate`: missing order-by attributes are not allowed to be added into group-by expressions since it will change the query result. Thus, in RDBMS, it is not allowed. - Aggregate expressions in `Aggregate`: if the group-by expressions in `Aggregate` contains the missing attributes but aggregate expressions do not have it, just add them into the aggregate expressions. This can resolve the analysisExceptions thrown by the three TCPDS queries. - `Project` and `Window` are special. We just need to add the missing attributes to their `projectList`. **Implementation:** 1. Traverse the whole tree in a pre-order manner to find all the resolvable missing order-by attributes. 2. Traverse the whole tree in a post-order manner to add the found missing order-by attributes to the node if their `inputSet` contains the attributes. 3. If the origins of the missing order-by attributes are different nodes, each pass only resolves the missing attributes that are from the same node. **Risk:** Low. This rule will be trigger iff ```!s.resolved && child.resolved``` is true. Thus, very few cases are affected. Author: gatorsmile Closes #10678 from gatorsmile/sortWindows. 
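To make the scope concrete, here is a hedged sketch of a query shape this rule now resolves. The data and column names are hypothetical, and a Spark 1.6-era `SQLContext` with its implicits in scope is assumed; it follows the new `DataFrameSuite` test where the sort column is dropped by a `Project` behind `Filter`s.

```scala
// Hypothetical data, shown only to illustrate the missing-sort-reference case.
// Assumes a SQLContext named `sqlContext` with its implicits imported.
import sqlContext.implicits._

val df = Seq((1, "b"), (2, "a"), (3, "c")).toDF("num", "letter")

// "letter" is dropped by select() and a filter() sits between the Sort and the leaf, so the
// ordering attribute is missing from the child's outputSet. ResolveSortReferences adds the
// missing attribute to the Project, performs the Sort, then projects it away again.
val sorted = df.filter($"num" > 0).select($"num").filter($"num" < 3).orderBy($"letter".asc)
sorted.show()
```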
--- .../sql/catalyst/analysis/Analyzer.scala | 101 ++++++++++++++---- .../sql/catalyst/analysis/AnalysisSuite.scala | 83 ++++++++++++++ .../sql/catalyst/analysis/TestRelations.scala | 6 ++ .../apache/spark/sql/DataFrameJoinSuite.scala | 16 +++ .../org/apache/spark/sql/DataFrameSuite.scala | 6 ++ .../sql/hive/execution/SQLQuerySuite.scala | 84 ++++++++++++++- 6 files changed, 274 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ee60fca1ad4f..a983dc1cdfeb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import scala.annotation.tailrec import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException @@ -452,7 +453,7 @@ class Analyzer( i.copy(right = dedupRight(left, right)) // When resolve `SortOrder`s in Sort based on child, don't report errors as - // we still have chance to resolve it based on grandchild + // we still have chance to resolve it based on its descendants case s @ Sort(ordering, global, child) if child.resolved && !s.resolved => val newOrdering = resolveSortOrders(ordering, child, throws = false) Sort(newOrdering, global, child) @@ -533,38 +534,96 @@ class Analyzer( */ object ResolveSortReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case s @ Sort(ordering, global, p @ Project(projectList, child)) - if !s.resolved && p.resolved => - val (newOrdering, missing) = resolveAndFindMissing(ordering, p, child) + // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions + case sa @ Sort(_, _, child: Aggregate) => sa - // If this rule was not a no-op, return the transformed plan, otherwise return the original. - if (missing.nonEmpty) { - // Add missing attributes and then project them away after the sort. - Project(p.output, - Sort(newOrdering, global, - Project(projectList ++ missing, child))) - } else { - logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}") + case s @ Sort(_, _, child) if !s.resolved && child.resolved => + val (newOrdering, missingResolvableAttrs) = collectResolvableMissingAttrs(s.order, child) + + if (missingResolvableAttrs.isEmpty) { + val unresolvableAttrs = s.order.filterNot(_.resolved) + logDebug(s"Failed to find $unresolvableAttrs in ${child.output.mkString(", ")}") s // Nothing we can do here. Return original plan. + } else { + // Add the missing attributes into projectList of Project/Window or + // aggregateExpressions of Aggregate, if they are in the inputSet + // but not in the outputSet of the plan. 
+ val newChild = child transformUp { + case p: Project => + p.copy(projectList = p.projectList ++ + missingResolvableAttrs.filter((p.inputSet -- p.outputSet).contains)) + case w: Window => + w.copy(projectList = w.projectList ++ + missingResolvableAttrs.filter((w.inputSet -- w.outputSet).contains)) + case a: Aggregate => + val resolvableAttrs = missingResolvableAttrs.filter(a.groupingExpressions.contains) + val notResolvedAttrs = resolvableAttrs.filterNot(a.aggregateExpressions.contains) + val newAggregateExpressions = a.aggregateExpressions ++ notResolvedAttrs + a.copy(aggregateExpressions = newAggregateExpressions) + case o => o + } + + // Add missing attributes and then project them away after the sort. + Project(child.output, + Sort(newOrdering, s.global, newChild)) } } /** - * Given a child and a grandchild that are present beneath a sort operator, try to resolve - * the sort ordering and returns it with a list of attributes that are missing from the - * child but are present in the grandchild. + * Traverse the tree until resolving the sorting attributes + * Return all the resolvable missing sorting attributes + */ + @tailrec + private def collectResolvableMissingAttrs( + ordering: Seq[SortOrder], + plan: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { + plan match { + // Only Windows and Project have projectList-like attribute. + case un: UnaryNode if un.isInstanceOf[Project] || un.isInstanceOf[Window] => + val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, un, un.child) + // If missingAttrs is non empty, that means we got it and return it; + // Otherwise, continue to traverse the tree. + if (missingAttrs.nonEmpty) { + (newOrdering, missingAttrs) + } else { + collectResolvableMissingAttrs(ordering, un.child) + } + case a: Aggregate => + val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, a, a.child) + // For Aggregate, all the order by columns must be specified in group by clauses + if (missingAttrs.nonEmpty && + missingAttrs.forall(ar => a.groupingExpressions.exists(_.semanticEquals(ar)))) { + (newOrdering, missingAttrs) + } else { + // If missingAttrs is empty, we are unable to resolve any unresolved missing attributes + (Seq.empty[SortOrder], Seq.empty[Attribute]) + } + // Jump over the following UnaryNode types + // The output of these types is the same as their child's output + case _: Distinct | + _: Filter | + _: RepartitionByExpression => + collectResolvableMissingAttrs(ordering, plan.asInstanceOf[UnaryNode].child) + // If hitting the other unsupported operators, we are unable to resolve it. + case other => (Seq.empty[SortOrder], Seq.empty[Attribute]) + } + } + + /** + * Try to resolve the sort ordering and returns it with a list of attributes that are missing + * from the plan but are present in the child. */ - def resolveAndFindMissing( + private def resolveAndFindMissing( ordering: Seq[SortOrder], - child: LogicalPlan, - grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { - val newOrdering = resolveSortOrders(ordering, grandchild, throws = true) + plan: LogicalPlan, + child: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { + val newOrdering = resolveSortOrders(ordering, child, throws = false) // Construct a set that contains all of the attributes that we need to evaluate the // ordering. val requiredAttributes = AttributeSet(newOrdering).filter(_.resolved) // Figure out which ones are missing from the projection, so that we can add them and // remove them after the sort. 
- val missingInProject = requiredAttributes -- child.output + val missingInProject = requiredAttributes -- plan.outputSet // It is important to return the new SortOrders here, instead of waiting for the standard // resolving process as adding attributes to the project below can actually introduce // ambiguity that was not present before. @@ -719,7 +778,7 @@ class Analyzer( } } - protected def containsAggregate(condition: Expression): Boolean = { + def containsAggregate(condition: Expression): Boolean = { condition.find(_.isInstanceOf[AggregateExpression]).isDefined } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 1938bce02a17..ebf885a8fe48 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -76,6 +76,89 @@ class AnalysisSuite extends AnalysisTest { caseSensitive = false) } + test("resolve sort references - filter/limit") { + val a = testRelation2.output(0) + val b = testRelation2.output(1) + val c = testRelation2.output(2) + + // Case 1: one missing attribute is in the leaf node and another is in the unary node + val plan1 = testRelation2 + .where('a > "str").select('a, 'b) + .where('b > "str").select('a) + .sortBy('b.asc, 'c.desc) + val expected1 = testRelation2 + .where(a > "str").select(a, b, c) + .where(b > "str").select(a, b, c) + .sortBy(b.asc, c.desc) + .select(a, b).select(a) + checkAnalysis(plan1, expected1) + + // Case 2: all the missing attributes are in the leaf node + val plan2 = testRelation2 + .where('a > "str").select('a) + .where('a > "str").select('a) + .sortBy('b.asc, 'c.desc) + val expected2 = testRelation2 + .where(a > "str").select(a, b, c) + .where(a > "str").select(a, b, c) + .sortBy(b.asc, c.desc) + .select(a) + checkAnalysis(plan2, expected2) + } + + test("resolve sort references - join") { + val a = testRelation2.output(0) + val b = testRelation2.output(1) + val c = testRelation2.output(2) + val h = testRelation3.output(3) + + // Case: join itself can resolve all the missing attributes + val plan = testRelation2.join(testRelation3) + .where('a > "str").select('a, 'b) + .sortBy('c.desc, 'h.asc) + val expected = testRelation2.join(testRelation3) + .where(a > "str").select(a, b, c, h) + .sortBy(c.desc, h.asc) + .select(a, b) + checkAnalysis(plan, expected) + } + + test("resolve sort references - aggregate") { + val a = testRelation2.output(0) + val b = testRelation2.output(1) + val c = testRelation2.output(2) + val alias_a3 = count(a).as("a3") + val alias_b = b.as("aggOrder") + + // Case 1: when the child of Sort is not Aggregate, + // the sort reference is handled by the rule ResolveSortReferences + val plan1 = testRelation2 + .groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3")) + .select('a, 'c, 'a3) + .orderBy('b.asc) + + val expected1 = testRelation2 + .groupBy(a, c, b)(a, c, alias_a3, b) + .select(a, c, alias_a3.toAttribute, b) + .orderBy(b.asc) + .select(a, c, alias_a3.toAttribute) + + checkAnalysis(plan1, expected1) + + // Case 2: when the child of Sort is Aggregate, + // the sort reference is handled by the rule ResolveAggregateFunctions + val plan2 = testRelation2 + .groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3")) + .orderBy('b.asc) + + val expected2 = testRelation2 + .groupBy(a, c, b)(a, c, alias_a3, alias_b) + .orderBy(alias_b.toAttribute.asc) + .select(a, c, 
alias_a3.toAttribute) + + checkAnalysis(plan2, expected2) + } + test("resolve relations") { assertAnalysisError( UnresolvedRelation(TableIdentifier("tAbLe"), None), Seq("Table not found: tAbLe")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala index bc07b609a341..3741a6ba95a8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala @@ -31,6 +31,12 @@ object TestRelations { AttributeReference("d", DecimalType(10, 2))(), AttributeReference("e", ShortType)()) + val testRelation3 = LocalRelation( + AttributeReference("e", ShortType)(), + AttributeReference("f", StringType)(), + AttributeReference("g", DoubleType)(), + AttributeReference("h", DecimalType(10, 2))()) + val nestedRelation = LocalRelation( AttributeReference("top", StructType( StructField("duplicateField", StringType) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index c17be8ace928..a5e5f156423c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -42,6 +42,22 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil) } + test("join - sorted columns not in join's outputSet") { + val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str_sort").as('df1) + val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as('df2) + val df3 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as('df3) + + checkAnswer( + df.join(df2, $"df1.int" === $"df2.int", "outer").select($"df1.int", $"df2.int2") + .orderBy('str_sort.asc, 'str.asc), + Row(null, 6) :: Row(1, 3) :: Row(3, null) :: Nil) + + checkAnswer( + df2.join(df3, $"df2.int" === $"df3.int", "inner") + .select($"df2.int", $"df3.int").orderBy($"df2.str".desc), + Row(5, 5) :: Row(1, 1) :: Nil) + } + test("join - join using multiple columns and specifying join type") { val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str") val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 4ff99bdf2937..c02133ffc854 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -954,6 +954,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(expected === actual) } + test("Sorting columns are not in Filter and Project") { + checkAnswer( + upperCaseData.filter('N > 1).select('N).filter('N < 6).orderBy('L.asc), + Row(2) :: Row(3) :: Row(4) :: Row(5) :: Nil) + } + test("SPARK-9323: DataFrame.orderBy should support nested column name") { val df = sqlContext.read.json(sparkContext.makeRDD( """{"a": {"b": 1}}""" :: Nil)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 1ada2e325bda..6048b8f5a399 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -736,7 +736,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { """.stripMargin), (2 to 6).map(i => Row(i))) } - test("window function: udaf with aggregate expressin") { + test("window function: udaf with aggregate expression") { val data = Seq( WindowData(1, "a", 5), WindowData(2, "a", 6), @@ -927,6 +927,88 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3, i._4))) } + test("window function: Sorting columns are not in Project") { + val data = Seq( + WindowData(1, "d", 10), + WindowData(2, "a", 6), + WindowData(3, "b", 7), + WindowData(4, "b", 8), + WindowData(5, "c", 9), + WindowData(6, "c", 11) + ) + sparkContext.parallelize(data).toDF().registerTempTable("windowData") + + checkAnswer( + sql("select month, product, sum(product + 1) over() from windowData order by area"), + Seq( + (2, 6, 57), + (3, 7, 57), + (4, 8, 57), + (5, 9, 57), + (6, 11, 57), + (1, 10, 57) + ).map(i => Row(i._1, i._2, i._3))) + + checkAnswer( + sql( + """ + |select area, rank() over (partition by area order by tmp.month) + tmp.tmp1 as c1 + |from (select month, area, product as p, 1 as tmp1 from windowData) tmp order by p + """.stripMargin), + Seq( + ("a", 2), + ("b", 2), + ("b", 3), + ("c", 2), + ("d", 2), + ("c", 3) + ).map(i => Row(i._1, i._2))) + + checkAnswer( + sql( + """ + |select area, rank() over (partition by area order by month) as c1 + |from windowData group by product, area, month order by product, area + """.stripMargin), + Seq( + ("a", 1), + ("b", 1), + ("b", 2), + ("c", 1), + ("d", 1), + ("c", 2) + ).map(i => Row(i._1, i._2))) + } + + // todo: fix this test case by reimplementing the function ResolveAggregateFunctions + ignore("window function: Pushing aggregate Expressions in Sort to Aggregate") { + val data = Seq( + WindowData(1, "d", 10), + WindowData(2, "a", 6), + WindowData(3, "b", 7), + WindowData(4, "b", 8), + WindowData(5, "c", 9), + WindowData(6, "c", 11) + ) + sparkContext.parallelize(data).toDF().registerTempTable("windowData") + + checkAnswer( + sql( + """ + |select area, sum(product) over () as c from windowData + |where product > 3 group by area, product + |having avg(month) > 0 order by avg(month), product + """.stripMargin), + Seq( + ("a", 51), + ("b", 51), + ("b", 51), + ("c", 51), + ("c", 51), + ("d", 51) + ).map(i => Row(i._1, i._2))) + } + test("window function: multiple window expressions in a single expression") { val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y") nums.registerTempTable("nums") From da9146c91a33577ff81378ca7e7c38a4b1917876 Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Mon, 1 Feb 2016 12:02:06 -0800 Subject: [PATCH 081/131] [DOCS] Fix the jar location of datanucleus in sql-programming-guid.md ISTM `lib` is better because `datanucleus` jars are located in `lib` for release builds. Author: Takeshi YAMAMURO Closes #10901 from maropu/DocFix. 
--- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index fddc51379406..550a40010e82 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1695,7 +1695,7 @@ on all of the worker nodes, as they will need access to the Hive serialization a Configuration of Hive is done by placing your `hive-site.xml`, `core-site.xml` (for security configuration), `hdfs-site.xml` (for HDFS configuration) file in `conf/`. Please note when running -the query on a YARN cluster (`cluster` mode), the `datanucleus` jars under the `lib_managed/jars` directory +the query on a YARN cluster (`cluster` mode), the `datanucleus` jars under the `lib` directory and `hive-site.xml` under `conf/` directory need to be available on the driver and all executors launched by the YARN cluster. The convenient way to do this is adding them through the `--jars` option and `--file` option of the `spark-submit` command. From 711ce048a285403241bbc9eaabffc1314162e89c Mon Sep 17 00:00:00 2001 From: Lewuathe Date: Mon, 1 Feb 2016 12:21:21 -0800 Subject: [PATCH 082/131] [ML][MINOR] Invalid MulticlassClassification reference in ml-guide In [ml-guide](https://spark.apache.org/docs/latest/ml-guide.html#example-model-selection-via-cross-validation), there is invalid reference to `MulticlassClassificationEvaluator` apidoc. https://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.ml.evaluation.MultiClassClassificationEvaluator Author: Lewuathe Closes #10996 from Lewuathe/fix-typo-in-ml-guide. --- docs/ml-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 5aafd53b584e..f8279262e673 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -627,7 +627,7 @@ Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/ The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.RegressionEvaluator) for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.BinaryClassificationEvaluator) -for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.MultiClassClassificationEvaluator) +for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator) for multiclass problems. The default metric used to choose the best `ParamMap` can be overriden by the `setMetricName` method in each of these evaluators. From 51b03b71ffc390e67b32936efba61e614a8b0d86 Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Mon, 1 Feb 2016 12:45:02 -0800 Subject: [PATCH 083/131] [SPARK-12463][SPARK-12464][SPARK-12465][SPARK-10647][MESOS] Fix zookeeper dir with mesos conf and add docs. Fix zookeeper dir configuration used in cluster mode, and also add documentation around these settings. Author: Timothy Chen Closes #10057 from tnachen/fix_mesos_dir. 
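As a hedged illustration of the renamed keys (not part of the patch; the ZooKeeper hosts and directory below are placeholders), the dispatcher's recovery settings now live in the generic `spark.deploy.*` namespace rather than `spark.mesos.deploy.*`:

```scala
// Sketch only. In practice the same keys are passed to the MesosClusterDispatcher daemon
// via SPARK_DAEMON_JAVA_OPTS (e.g. -Dspark.deploy.recoveryMode=ZOOKEEPER), as the updated
// docs describe; the values below are placeholders.
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.deploy.recoveryMode", "ZOOKEEPER")           // was spark.mesos.deploy.recoveryMode
  .set("spark.deploy.zookeeper.url", "zk1:2181,zk2:2181")  // placeholder ZooKeeper quorum
  .set("spark.deploy.zookeeper.dir", "/spark")             // placeholder recovery directory
```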
--- .../deploy/mesos/MesosClusterDispatcher.scala | 6 ++--- .../mesos/MesosClusterPersistenceEngine.scala | 4 ++-- docs/configuration.md | 23 +++++++++++++++++++ docs/running-on-mesos.md | 5 +++- docs/spark-standalone.md | 23 ++++--------------- 5 files changed, 36 insertions(+), 25 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala index 66e1e645007a..9b31497adfb1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -50,7 +50,7 @@ private[mesos] class MesosClusterDispatcher( extends Logging { private val publicAddress = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(args.host) - private val recoveryMode = conf.get("spark.mesos.deploy.recoveryMode", "NONE").toUpperCase() + private val recoveryMode = conf.get("spark.deploy.recoveryMode", "NONE").toUpperCase() logInfo("Recovery mode in Mesos dispatcher set to: " + recoveryMode) private val engineFactory = recoveryMode match { @@ -98,8 +98,8 @@ private[mesos] object MesosClusterDispatcher extends Logging { conf.setMaster(dispatcherArgs.masterUrl) conf.setAppName(dispatcherArgs.name) dispatcherArgs.zookeeperUrl.foreach { z => - conf.set("spark.mesos.deploy.recoveryMode", "ZOOKEEPER") - conf.set("spark.mesos.deploy.zookeeper.url", z) + conf.set("spark.deploy.recoveryMode", "ZOOKEEPER") + conf.set("spark.deploy.zookeeper.url", z) } val dispatcher = new MesosClusterDispatcher(dispatcherArgs, conf) dispatcher.start() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala index e0c547dce6d0..092d9e418253 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala @@ -53,9 +53,9 @@ private[spark] trait MesosClusterPersistenceEngine { * all of them reuses the same connection pool. */ private[spark] class ZookeeperMesosClusterPersistenceEngineFactory(conf: SparkConf) - extends MesosClusterPersistenceEngineFactory(conf) { + extends MesosClusterPersistenceEngineFactory(conf) with Logging { - lazy val zk = SparkCuratorUtil.newClient(conf, "spark.mesos.deploy.zookeeper.url") + lazy val zk = SparkCuratorUtil.newClient(conf) def createEngine(path: String): MesosClusterPersistenceEngine = { new ZookeeperMesosClusterPersistenceEngine(path, zk, conf) diff --git a/docs/configuration.md b/docs/configuration.md index 74a8fb5d35a6..93b399d819cc 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1585,6 +1585,29 @@ Apart from these, the following properties are also available, and may be useful +#### Deploy + + + + + + + + + + + + + + + + + + +
+  <tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+  <tr>
+    <td><code>spark.deploy.recoveryMode</code></td>
+    <td>NONE</td>
+    <td>The recovery mode setting to recover submitted Spark jobs with cluster mode when it failed and relaunches.
+      This is only applicable for cluster mode when running with Standalone or Mesos.</td>
+  </tr>
+  <tr>
+    <td><code>spark.deploy.zookeeper.url</code></td>
+    <td>None</td>
+    <td>When `spark.deploy.recoveryMode` is set to ZOOKEEPER, this configuration is used to set the zookeeper URL to connect to.</td>
+  </tr>
+  <tr>
+    <td><code>spark.deploy.zookeeper.dir</code></td>
+    <td>None</td>
+    <td>When `spark.deploy.recoveryMode` is set to ZOOKEEPER, this configuration is used to set the zookeeper directory to store recovery state.</td>
+  </tr>
        + + #### Cluster Managers Each cluster manager in Spark has additional configuration options. Configurations can be found on the pages for each mode: diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index ed720f1039f9..0ef1ccb36e11 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -153,7 +153,10 @@ can find the results of the driver from the Mesos Web UI. To use cluster mode, you must start the `MesosClusterDispatcher` in your cluster via the `sbin/start-mesos-dispatcher.sh` script, passing in the Mesos master URL (e.g: mesos://host:5050). This starts the `MesosClusterDispatcher` as a daemon running on the host. -If you like to run the `MesosClusterDispatcher` with Marathon, you need to run the `MesosClusterDispatcher` in the foreground (i.e: `bin/spark-class org.apache.spark.deploy.mesos.MesosClusterDispatcher`). +If you like to run the `MesosClusterDispatcher` with Marathon, you need to run the `MesosClusterDispatcher` in the foreground (i.e: `bin/spark-class org.apache.spark.deploy.mesos.MesosClusterDispatcher`). Note that the `MesosClusterDispatcher` not yet supports multiple instances for HA. + +The `MesosClusterDispatcher` also supports writing recovery state into Zookeeper. This will allow the `MesosClusterDispatcher` to be able to recover all submitted and running containers on relaunch. In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env by configuring `spark.deploy.recoveryMode` and related spark.deploy.zookeeper.* configurations. +For more information about these configurations please refer to the configurations (doc)[configurations.html#deploy]. From the client, you can submit a job to Mesos cluster by running `spark-submit` and specifying the master URL to the URL of the `MesosClusterDispatcher` (e.g: mesos://dispatcher:7077). You can view driver statuses on the diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 2fe9ec3542b2..3de72bc016dd 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -112,8 +112,8 @@ You can optionally configure the cluster further by setting environment variable SPARK_LOCAL_DIRS - Directory to use for "scratch" space in Spark, including map output files and RDDs that get - stored on disk. This should be on a fast, local disk in your system. It can also be a + Directory to use for "scratch" space in Spark, including map output files and RDDs that get + stored on disk. This should be on a fast, local disk in your system. It can also be a comma-separated list of multiple directories on different disks. @@ -341,23 +341,8 @@ Learn more about getting started with ZooKeeper [here](http://zookeeper.apache.o **Configuration** -In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env using this configuration: - - - - - - - - - - - - - - - -
-  <tr><th>System property</th><th>Meaning</th></tr>
-  <tr>
-    <td><code>spark.deploy.recoveryMode</code></td>
-    <td>Set to ZOOKEEPER to enable standby Master recovery mode (default: NONE).</td>
-  </tr>
-  <tr>
-    <td><code>spark.deploy.zookeeper.url</code></td>
-    <td>The ZooKeeper cluster url (e.g., 192.168.1.100:2181,192.168.1.101:2181).</td>
-  </tr>
-  <tr>
-    <td><code>spark.deploy.zookeeper.dir</code></td>
-    <td>The directory in ZooKeeper to store recovery state (default: /spark).</td>
-  </tr>
        +In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env by configuring `spark.deploy.recoveryMode` and related spark.deploy.zookeeper.* configurations. +For more information about these configurations please refer to the configurations (doc)[configurations.html#deploy] Possible gotcha: If you have multiple Masters in your cluster but fail to correctly configure the Masters to use ZooKeeper, the Masters will fail to discover each other and think they're all leaders. This will not lead to a healthy cluster state (as all Masters will schedule independently). From a41b68b954ba47284a1df312f0aaea29b0721b0a Mon Sep 17 00:00:00 2001 From: Nilanjan Raychaudhuri Date: Mon, 1 Feb 2016 13:33:24 -0800 Subject: [PATCH 084/131] [SPARK-12265][MESOS] Spark calls System.exit inside driver instead of throwing exception This takes over #10729 and makes sure that `spark-shell` fails with a proper error message. There is a slight behavioral change: before this change `spark-shell` would exit, while now the REPL is still there, but `sc` and `sqlContext` are not defined and the error is visible to the user. Author: Nilanjan Raychaudhuri Author: Iulian Dragos Closes #10921 from dragos/pr/10729. --- .../cluster/mesos/MesosClusterScheduler.scala | 1 + .../cluster/mesos/MesosSchedulerBackend.scala | 1 + .../cluster/mesos/MesosSchedulerUtils.scala | 21 +++++++++++++++---- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 05fda0fded7f..e77d77208ccb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -573,6 +573,7 @@ private[spark] class MesosClusterScheduler( override def slaveLost(driver: SchedulerDriver, slaveId: SlaveID): Unit = {} override def error(driver: SchedulerDriver, error: String): Unit = { logError("Error received: " + error) + markErr() } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index eaf0cb06d6c7..a8bf79a78cf5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -375,6 +375,7 @@ private[spark] class MesosSchedulerBackend( override def error(d: SchedulerDriver, message: String) { inClassLoader() { logError("Mesos error: " + message) + markErr() scheduler.error(message) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 010caff3e39b..f9f5da9bc8df 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -106,28 +106,37 @@ private[mesos] trait MesosSchedulerUtils extends Logging { registerLatch.await() return } + @volatile + var error: Option[Exception] = None + // We create a new thread that will block inside `mesosDriver.run` + // until the scheduler exists new Thread(Utils.getFormattedClassName(this) + "-mesos-driver") { setDaemon(true) - override 
def run() { - mesosDriver = newDriver try { + mesosDriver = newDriver val ret = mesosDriver.run() logInfo("driver.run() returned with code " + ret) if (ret != null && ret.equals(Status.DRIVER_ABORTED)) { - System.exit(1) + error = Some(new SparkException("Error starting driver, DRIVER_ABORTED")) + markErr() } } catch { case e: Exception => { logError("driver.run() failed", e) - System.exit(1) + error = Some(e) + markErr() } } } }.start() registerLatch.await() + + // propagate any error to the calling thread. This ensures that SparkContext creation fails + // without leaving a broken context that won't be able to schedule any tasks + error.foreach(throw _) } } @@ -144,6 +153,10 @@ private[mesos] trait MesosSchedulerUtils extends Logging { registerLatch.countDown() } + protected def markErr(): Unit = { + registerLatch.countDown() + } + def createResource(name: String, amount: Double, role: Option[String] = None): Resource = { val builder = Resource.newBuilder() .setName(name) From c9b89a0a0921ce3d52864afd4feb7f37b90f7b46 Mon Sep 17 00:00:00 2001 From: Iulian Dragos Date: Mon, 1 Feb 2016 13:38:38 -0800 Subject: [PATCH 085/131] =?UTF-8?q?[SPARK-12979][MESOS]=20Don=E2=80=99t=20?= =?UTF-8?q?resolve=20paths=20on=20the=20local=20file=20system=20in=20Mesos?= =?UTF-8?q?=20scheduler?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The driver filesystem is likely different from where the executors will run, so resolving paths (and symlinks, etc.) will lead to invalid paths on executors. Author: Iulian Dragos Closes #10923 from dragos/issue/canonical-paths. --- .../scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala | 2 +- .../spark/scheduler/cluster/mesos/MesosClusterScheduler.scala | 2 +- .../spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 58c30e7d9788..2f095b86c69e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -179,7 +179,7 @@ private[spark] class CoarseMesosSchedulerBackend( .orElse(Option(System.getenv("SPARK_EXECUTOR_URI"))) if (uri.isEmpty) { - val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath + val runScript = new File(executorSparkHome, "./bin/spark-class").getPath command.setValue( "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend" .format(prefixEnv, runScript) + diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index e77d77208ccb..8cda4ff0eb3b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -394,7 +394,7 @@ private[spark] class MesosClusterScheduler( .getOrElse { throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") } - val cmdExecutable = new File(executorSparkHome, "./bin/spark-submit").getCanonicalPath + val cmdExecutable = new File(executorSparkHome, "./bin/spark-submit").getPath // Sandbox points to the current directory by default 
with Mesos. (cmdExecutable, ".") } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index a8bf79a78cf5..340f29bac921 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -125,7 +125,7 @@ private[spark] class MesosSchedulerBackend( val executorBackendName = classOf[MesosExecutorBackend].getName if (uri.isEmpty) { - val executorPath = new File(executorSparkHome, "/bin/spark-class").getCanonicalPath + val executorPath = new File(executorSparkHome, "/bin/spark-class").getPath command.setValue(s"$prefixEnv $executorPath $executorBackendName") } else { // Grab everything to the first '.'. We'll use that and '*' to From 064b029c6a15481fc4dfb147100c19a68cd1cc95 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Mon, 1 Feb 2016 13:56:14 -0800 Subject: [PATCH 086/131] [SPARK-13043][SQL] Implement remaining catalyst types in ColumnarBatch. This includes: float, boolean, short, decimal and calendar interval. Decimal is mapped to long or byte array depending on the size and calendar interval is mapped to a struct of int and long. The only remaining type is map. The schema mapping is straightforward but we might want to revisit how we deal with this in the rest of the execution engine. Author: Nong Li Closes #10961 from nongli/spark-13043. --- .../apache/spark/sql/types/DecimalType.scala | 22 +++ .../execution/vectorized/ColumnVector.java | 180 +++++++++++++++++- .../vectorized/ColumnVectorUtils.java | 34 +++- .../execution/vectorized/ColumnarBatch.java | 46 +++-- .../vectorized/OffHeapColumnVector.java | 98 +++++++++- .../vectorized/OnHeapColumnVector.java | 94 ++++++++- .../vectorized/ColumnarBatchSuite.scala | 44 ++++- .../org/apache/spark/unsafe/Platform.java | 8 + 8 files changed, 484 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index cf5322125bd7..5dd661ee6b33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -148,6 +148,28 @@ object DecimalType extends AbstractDataType { } } + /** + * Returns if dt is a DecimalType that fits inside a long + */ + def is64BitDecimalType(dt: DataType): Boolean = { + dt match { + case t: DecimalType => + t.precision <= Decimal.MAX_LONG_DIGITS + case _ => false + } + } + + /** + * Returns if dt is a DecimalType that doesn't fit inside a long + */ + def isByteArrayDecimalType(dt: DataType): Boolean = { + dt match { + case t: DecimalType => + t.precision > Decimal.MAX_LONG_DIGITS + case _ => false + } + } + def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType] def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType] diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index a0bf8734b654..a5bc506a65ac 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -16,6 +16,9 @@ */ package org.apache.spark.sql.execution.vectorized; +import 
java.math.BigDecimal; +import java.math.BigInteger; + import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ArrayData; @@ -102,18 +105,36 @@ public Object[] array() { DataType dt = data.dataType(); Object[] list = new Object[length]; - if (dt instanceof ByteType) { + if (dt instanceof BooleanType) { + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = data.getBoolean(offset + i); + } + } + } else if (dt instanceof ByteType) { for (int i = 0; i < length; i++) { if (!data.getIsNull(offset + i)) { list[i] = data.getByte(offset + i); } } + } else if (dt instanceof ShortType) { + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = data.getShort(offset + i); + } + } } else if (dt instanceof IntegerType) { for (int i = 0; i < length; i++) { if (!data.getIsNull(offset + i)) { list[i] = data.getInt(offset + i); } } + } else if (dt instanceof FloatType) { + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = data.getFloat(offset + i); + } + } } else if (dt instanceof DoubleType) { for (int i = 0; i < length; i++) { if (!data.getIsNull(offset + i)) { @@ -126,12 +147,25 @@ public Object[] array() { list[i] = data.getLong(offset + i); } } + } else if (dt instanceof DecimalType) { + DecimalType decType = (DecimalType)dt; + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = getDecimal(i, decType.precision(), decType.scale()); + } + } } else if (dt instanceof StringType) { for (int i = 0; i < length; i++) { if (!data.getIsNull(offset + i)) { list[i] = ColumnVectorUtils.toString(data.getByteArray(offset + i)); } } + } else if (dt instanceof CalendarIntervalType) { + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = getInterval(i); + } + } } else { throw new NotImplementedException("Type " + dt); } @@ -170,7 +204,14 @@ public float getFloat(int ordinal) { @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - throw new NotImplementedException(); + if (precision <= Decimal.MAX_LONG_DIGITS()) { + return Decimal.apply(getLong(ordinal), precision, scale); + } else { + byte[] bytes = getBinary(ordinal); + BigInteger bigInteger = new BigInteger(bytes); + BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); + return Decimal.apply(javaDecimal, precision, scale); + } } @Override @@ -181,17 +222,22 @@ public UTF8String getUTF8String(int ordinal) { @Override public byte[] getBinary(int ordinal) { - throw new NotImplementedException(); + ColumnVector.Array array = data.getByteArray(offset + ordinal); + byte[] bytes = new byte[array.length]; + System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); + return bytes; } @Override public CalendarInterval getInterval(int ordinal) { - throw new NotImplementedException(); + int month = data.getChildColumn(0).getInt(offset + ordinal); + long microseconds = data.getChildColumn(1).getLong(offset + ordinal); + return new CalendarInterval(month, microseconds); } @Override public InternalRow getStruct(int ordinal, int numFields) { - throw new NotImplementedException(); + return data.getStruct(offset + ordinal); } @Override @@ -279,6 +325,21 @@ public void reset() { */ public abstract boolean getIsNull(int rowId); + /** + * Sets the value at rowId to `value`. + */ + public abstract void putBoolean(int rowId, boolean value); + + /** + * Sets values from [rowId, rowId + count) to value. 
+ */ + public abstract void putBooleans(int rowId, int count, boolean value); + + /** + * Returns the value for rowId. + */ + public abstract boolean getBoolean(int rowId); + /** * Sets the value at rowId to `value`. */ @@ -299,6 +360,26 @@ public void reset() { */ public abstract byte getByte(int rowId); + /** + * Sets the value at rowId to `value`. + */ + public abstract void putShort(int rowId, short value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putShorts(int rowId, int count, short value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + */ + public abstract void putShorts(int rowId, int count, short[] src, int srcIndex); + + /** + * Returns the value for rowId. + */ + public abstract short getShort(int rowId); + /** * Sets the value at rowId to `value`. */ @@ -351,6 +432,33 @@ public void reset() { */ public abstract long getLong(int rowId); + /** + * Sets the value at rowId to `value`. + */ + public abstract void putFloat(int rowId, float value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putFloats(int rowId, int count, float value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * src should contain `count` doubles written as ieee format. + */ + public abstract void putFloats(int rowId, int count, float[] src, int srcIndex); + + /** + * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) + * The data in src must be ieee formatted floats. + */ + public abstract void putFloats(int rowId, int count, byte[] src, int srcIndex); + + /** + * Returns the value for rowId. + */ + public abstract float getFloat(int rowId); + /** * Sets the value at rowId to `value`. */ @@ -369,7 +477,7 @@ public void reset() { /** * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) - * The data in src must be ieee formated doubles. + * The data in src must be ieee formatted doubles. 
*/ public abstract void putDoubles(int rowId, int count, byte[] src, int srcIndex); @@ -469,6 +577,20 @@ public final int appendNotNulls(int count) { return result; } + public final int appendBoolean(boolean v) { + reserve(elementsAppended + 1); + putBoolean(elementsAppended, v); + return elementsAppended++; + } + + public final int appendBooleans(int count, boolean v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putBooleans(elementsAppended, count, v); + elementsAppended += count; + return result; + } + public final int appendByte(byte v) { reserve(elementsAppended + 1); putByte(elementsAppended, v); @@ -491,6 +613,28 @@ public final int appendBytes(int length, byte[] src, int offset) { return result; } + public final int appendShort(short v) { + reserve(elementsAppended + 1); + putShort(elementsAppended, v); + return elementsAppended++; + } + + public final int appendShorts(int count, short v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putShorts(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendShorts(int length, short[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putShorts(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + public final int appendInt(int v) { reserve(elementsAppended + 1); putInt(elementsAppended, v); @@ -535,6 +679,20 @@ public final int appendLongs(int length, long[] src, int offset) { return result; } + public final int appendFloat(float v) { + reserve(elementsAppended + 1); + putFloat(elementsAppended, v); + return elementsAppended++; + } + + public final int appendFloats(int count, float v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putFloats(elementsAppended, count, v); + elementsAppended += count; + return result; + } + public final int appendDouble(double v) { reserve(elementsAppended + 1); putDouble(elementsAppended, v); @@ -661,7 +819,8 @@ protected ColumnVector(int capacity, DataType type, MemoryMode memMode) { this.capacity = capacity; this.type = type; - if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType) { + if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType + || DecimalType.isByteArrayDecimalType(type)) { DataType childType; int childCapacity = capacity; if (type instanceof ArrayType) { @@ -682,6 +841,13 @@ protected ColumnVector(int capacity, DataType type, MemoryMode memMode) { } this.resultArray = null; this.resultStruct = new ColumnarBatch.Row(this.childColumns); + } else if (type instanceof CalendarIntervalType) { + // Two columns. Months as int. Microseconds as Long. 
+ this.childColumns = new ColumnVector[2]; + this.childColumns[0] = ColumnVector.allocate(capacity, DataTypes.IntegerType, memMode); + this.childColumns[1] = ColumnVector.allocate(capacity, DataTypes.LongType, memMode); + this.resultArray = null; + this.resultStruct = new ColumnarBatch.Row(this.childColumns); } else { this.childColumns = null; this.resultArray = null; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index 6c651a759d25..453bc15e1350 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -16,12 +16,15 @@ */ package org.apache.spark.sql.execution.vectorized; +import java.math.BigDecimal; +import java.math.BigInteger; import java.util.Iterator; import java.util.List; import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.Row; import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.commons.lang.NotImplementedException; @@ -59,19 +62,44 @@ public static Object toPrimitiveJavaArray(ColumnVector.Array array) { private static void appendValue(ColumnVector dst, DataType t, Object o) { if (o == null) { - dst.appendNull(); + if (t instanceof CalendarIntervalType) { + dst.appendStruct(true); + } else { + dst.appendNull(); + } } else { - if (t == DataTypes.ByteType) { - dst.appendByte(((Byte)o).byteValue()); + if (t == DataTypes.BooleanType) { + dst.appendBoolean(((Boolean)o).booleanValue()); + } else if (t == DataTypes.ByteType) { + dst.appendByte(((Byte) o).byteValue()); + } else if (t == DataTypes.ShortType) { + dst.appendShort(((Short)o).shortValue()); } else if (t == DataTypes.IntegerType) { dst.appendInt(((Integer)o).intValue()); } else if (t == DataTypes.LongType) { dst.appendLong(((Long)o).longValue()); + } else if (t == DataTypes.FloatType) { + dst.appendFloat(((Float)o).floatValue()); } else if (t == DataTypes.DoubleType) { dst.appendDouble(((Double)o).doubleValue()); } else if (t == DataTypes.StringType) { byte[] b =((String)o).getBytes(); dst.appendByteArray(b, 0, b.length); + } else if (t instanceof DecimalType) { + DecimalType dt = (DecimalType)t; + Decimal d = Decimal.apply((BigDecimal)o, dt.precision(), dt.scale()); + if (dt.precision() <= Decimal.MAX_LONG_DIGITS()) { + dst.appendLong(d.toUnscaledLong()); + } else { + final BigInteger integer = d.toJavaBigDecimal().unscaledValue(); + byte[] bytes = integer.toByteArray(); + dst.appendByteArray(bytes, 0, bytes.length); + } + } else if (t instanceof CalendarIntervalType) { + CalendarInterval c = (CalendarInterval)o; + dst.appendStruct(false); + dst.getChildColumn(0).appendInt(c.months); + dst.getChildColumn(1).appendLong(c.microseconds); } else { throw new NotImplementedException("Type " + t); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 5a575811fa89..dbad5e070f1f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.execution.vectorized; +import java.math.BigDecimal; +import java.math.BigInteger; import java.util.Arrays; import java.util.Iterator; @@ 
-25,6 +27,7 @@ import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -150,44 +153,40 @@ public final boolean anyNull() { } @Override - public final boolean isNullAt(int ordinal) { - return columns[ordinal].getIsNull(rowId); - } + public final boolean isNullAt(int ordinal) { return columns[ordinal].getIsNull(rowId); } @Override - public final boolean getBoolean(int ordinal) { - throw new NotImplementedException(); - } + public final boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); } @Override public final byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); } @Override - public final short getShort(int ordinal) { - throw new NotImplementedException(); - } + public final short getShort(int ordinal) { return columns[ordinal].getShort(rowId); } @Override - public final int getInt(int ordinal) { - return columns[ordinal].getInt(rowId); - } + public final int getInt(int ordinal) { return columns[ordinal].getInt(rowId); } @Override public final long getLong(int ordinal) { return columns[ordinal].getLong(rowId); } @Override - public final float getFloat(int ordinal) { - throw new NotImplementedException(); - } + public final float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); } @Override - public final double getDouble(int ordinal) { - return columns[ordinal].getDouble(rowId); - } + public final double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); } @Override public final Decimal getDecimal(int ordinal, int precision, int scale) { - throw new NotImplementedException(); + if (precision <= Decimal.MAX_LONG_DIGITS()) { + return Decimal.apply(getLong(ordinal), precision, scale); + } else { + // TODO: best perf? 
+ byte[] bytes = getBinary(ordinal); + BigInteger bigInteger = new BigInteger(bytes); + BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); + return Decimal.apply(javaDecimal, precision, scale); + } } @Override @@ -198,12 +197,17 @@ public final UTF8String getUTF8String(int ordinal) { @Override public final byte[] getBinary(int ordinal) { - throw new NotImplementedException(); + ColumnVector.Array array = columns[ordinal].getByteArray(rowId); + byte[] bytes = new byte[array.length]; + System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); + return bytes; } @Override public final CalendarInterval getInterval(int ordinal) { - throw new NotImplementedException(); + final int months = columns[ordinal].getChildColumn(0).getInt(rowId); + final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId); + return new CalendarInterval(months, microseconds); } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 335124fd5a60..22c5e5fc81a4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -19,11 +19,15 @@ import java.nio.ByteOrder; import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.types.BooleanType; import org.apache.spark.sql.types.ByteType; import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DecimalType; import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.FloatType; import org.apache.spark.sql.types.IntegerType; import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.ShortType; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.UTF8String; @@ -121,6 +125,26 @@ public final boolean getIsNull(int rowId) { return Platform.getByte(null, nulls + rowId) == 1; } + // + // APIs dealing with Booleans + // + + @Override + public final void putBoolean(int rowId, boolean value) { + Platform.putByte(null, data + rowId, (byte)((value) ? 1 : 0)); + } + + @Override + public final void putBooleans(int rowId, int count, boolean value) { + byte v = (byte)((value) ? 
1 : 0); + for (int i = 0; i < count; ++i) { + Platform.putByte(null, data + rowId + i, v); + } + } + + @Override + public final boolean getBoolean(int rowId) { return Platform.getByte(null, data + rowId) == 1; } + // // APIs dealing with Bytes // @@ -148,6 +172,34 @@ public final byte getByte(int rowId) { return Platform.getByte(null, data + rowId); } + // + // APIs dealing with shorts + // + + @Override + public final void putShort(int rowId, short value) { + Platform.putShort(null, data + 2 * rowId, value); + } + + @Override + public final void putShorts(int rowId, int count, short value) { + long offset = data + 2 * rowId; + for (int i = 0; i < count; ++i, offset += 4) { + Platform.putShort(null, offset, value); + } + } + + @Override + public final void putShorts(int rowId, int count, short[] src, int srcIndex) { + Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2, + null, data + 2 * rowId, count * 2); + } + + @Override + public final short getShort(int rowId) { + return Platform.getShort(null, data + 2 * rowId); + } + // // APIs dealing with ints // @@ -216,6 +268,41 @@ public final long getLong(int rowId) { return Platform.getLong(null, data + 8 * rowId); } + // + // APIs dealing with floats + // + + @Override + public final void putFloat(int rowId, float value) { + Platform.putFloat(null, data + rowId * 4, value); + } + + @Override + public final void putFloats(int rowId, int count, float value) { + long offset = data + 4 * rowId; + for (int i = 0; i < count; ++i, offset += 4) { + Platform.putFloat(null, offset, value); + } + } + + @Override + public final void putFloats(int rowId, int count, float[] src, int srcIndex) { + Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4, + null, data + 4 * rowId, count * 4); + } + + @Override + public final void putFloats(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, + null, data + rowId * 4, count * 4); + } + + @Override + public final float getFloat(int rowId) { + return Platform.getFloat(null, data + rowId * 4); + } + + // // APIs dealing with doubles // @@ -241,7 +328,7 @@ public final void putDoubles(int rowId, int count, double[] src, int srcIndex) { @Override public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) { - Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex, + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, null, data + rowId * 8, count * 8); } @@ -300,11 +387,14 @@ private final void reserveInternal(int newCapacity) { Platform.reallocateMemory(lengthData, elementsAppended * 4, newCapacity * 4); this.offsetData = Platform.reallocateMemory(offsetData, elementsAppended * 4, newCapacity * 4); - } else if (type instanceof ByteType) { + } else if (type instanceof ByteType || type instanceof BooleanType) { this.data = Platform.reallocateMemory(data, elementsAppended, newCapacity); - } else if (type instanceof IntegerType) { + } else if (type instanceof ShortType) { + this.data = Platform.reallocateMemory(data, elementsAppended * 2, newCapacity * 2); + } else if (type instanceof IntegerType || type instanceof FloatType) { this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4); - } else if (type instanceof LongType || type instanceof DoubleType) { + } else if (type instanceof LongType || type instanceof DoubleType || + DecimalType.is64BitDecimalType(type)) { this.data = Platform.reallocateMemory(data, elementsAppended * 8, newCapacity * 8); } else if 
(resultStruct != null) { // Nothing to store. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 8197fa11cd4c..32356334c031 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -35,8 +35,10 @@ public final class OnHeapColumnVector extends ColumnVector { // Array for each type. Only 1 is populated for any type. private byte[] byteData; + private short[] shortData; private int[] intData; private long[] longData; + private float[] floatData; private double[] doubleData; // Only set if type is Array. @@ -104,6 +106,30 @@ public final boolean getIsNull(int rowId) { return nulls[rowId] == 1; } + // + // APIs dealing with Booleans + // + + @Override + public final void putBoolean(int rowId, boolean value) { + byteData[rowId] = (byte)((value) ? 1 : 0); + } + + @Override + public final void putBooleans(int rowId, int count, boolean value) { + byte v = (byte)((value) ? 1 : 0); + for (int i = 0; i < count; ++i) { + byteData[i + rowId] = v; + } + } + + @Override + public final boolean getBoolean(int rowId) { + return byteData[rowId] == 1; + } + + // + // // APIs dealing with Bytes // @@ -130,6 +156,33 @@ public final byte getByte(int rowId) { return byteData[rowId]; } + // + // APIs dealing with Shorts + // + + @Override + public final void putShort(int rowId, short value) { + shortData[rowId] = value; + } + + @Override + public final void putShorts(int rowId, int count, short value) { + for (int i = 0; i < count; ++i) { + shortData[i + rowId] = value; + } + } + + @Override + public final void putShorts(int rowId, int count, short[] src, int srcIndex) { + System.arraycopy(src, srcIndex, shortData, rowId, count); + } + + @Override + public final short getShort(int rowId) { + return shortData[rowId]; + } + + // // APIs dealing with Ints // @@ -202,6 +255,31 @@ public final long getLong(int rowId) { return longData[rowId]; } + // + // APIs dealing with floats + // + + @Override + public final void putFloat(int rowId, float value) { floatData[rowId] = value; } + + @Override + public final void putFloats(int rowId, int count, float value) { + Arrays.fill(floatData, rowId, rowId + count, value); + } + + @Override + public final void putFloats(int rowId, int count, float[] src, int srcIndex) { + System.arraycopy(src, srcIndex, floatData, rowId, count); + } + + @Override + public final void putFloats(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, + floatData, Platform.DOUBLE_ARRAY_OFFSET + rowId * 4, count * 4); + } + + @Override + public final float getFloat(int rowId) { return floatData[rowId]; } // // APIs dealing with doubles @@ -277,7 +355,7 @@ public final void reserve(int requiredCapacity) { // Spilt this function out since it is the slow path. 
private final void reserveInternal(int newCapacity) { - if (this.resultArray != null) { + if (this.resultArray != null || DecimalType.isByteArrayDecimalType(type)) { int[] newLengths = new int[newCapacity]; int[] newOffsets = new int[newCapacity]; if (this.arrayLengths != null) { @@ -286,18 +364,30 @@ private final void reserveInternal(int newCapacity) { } arrayLengths = newLengths; arrayOffsets = newOffsets; + } else if (type instanceof BooleanType) { + byte[] newData = new byte[newCapacity]; + if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); + byteData = newData; } else if (type instanceof ByteType) { byte[] newData = new byte[newCapacity]; if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); byteData = newData; + } else if (type instanceof ShortType) { + short[] newData = new short[newCapacity]; + if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended); + shortData = newData; } else if (type instanceof IntegerType) { int[] newData = new int[newCapacity]; if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended); intData = newData; - } else if (type instanceof LongType) { + } else if (type instanceof LongType || DecimalType.is64BitDecimalType(type)) { long[] newData = new long[newCapacity]; if (longData != null) System.arraycopy(longData, 0, newData, 0, elementsAppended); longData = newData; + } else if (type instanceof FloatType) { + float[] newData = new float[newCapacity]; + if (floatData != null) System.arraycopy(floatData, 0, newData, 0, elementsAppended); + floatData = newData; } else if (type instanceof DoubleType) { double[] newData = new double[newCapacity]; if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, elementsAppended); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 67cc08b6fc8b..445f311107e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.types.CalendarInterval class ColumnarBatchSuite extends SparkFunSuite { test("Null Apis") { @@ -571,7 +572,6 @@ class ColumnarBatchSuite extends SparkFunSuite { }} } - private def doubleEquals(d1: Double, d2: Double): Boolean = { if (d1.isNaN && d2.isNaN) { true @@ -585,13 +585,23 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(r1.isNullAt(v._2) == r2.isNullAt(v._2), "Seed = " + seed) if (!r1.isNullAt(v._2)) { v._1.dataType match { + case BooleanType => assert(r1.getBoolean(v._2) == r2.getBoolean(v._2), "Seed = " + seed) case ByteType => assert(r1.getByte(v._2) == r2.getByte(v._2), "Seed = " + seed) + case ShortType => assert(r1.getShort(v._2) == r2.getShort(v._2), "Seed = " + seed) case IntegerType => assert(r1.getInt(v._2) == r2.getInt(v._2), "Seed = " + seed) case LongType => assert(r1.getLong(v._2) == r2.getLong(v._2), "Seed = " + seed) + case FloatType => assert(doubleEquals(r1.getFloat(v._2), r2.getFloat(v._2)), + "Seed = " + seed) case DoubleType => assert(doubleEquals(r1.getDouble(v._2), r2.getDouble(v._2)), "Seed = " + seed) + case t: DecimalType => + val d1 = 
r1.getDecimal(v._2, t.precision, t.scale).toBigDecimal + val d2 = r2.getDecimal(v._2) + assert(d1.compare(d2) == 0, "Seed = " + seed) case StringType => assert(r1.getString(v._2) == r2.getString(v._2), "Seed = " + seed) + case CalendarIntervalType => + assert(r1.getInterval(v._2) === r2.get(v._2).asInstanceOf[CalendarInterval]) case ArrayType(childType, n) => val a1 = r1.getArray(v._2).array val a2 = r2.getList(v._2).toArray @@ -605,6 +615,27 @@ class ColumnarBatchSuite extends SparkFunSuite { i += 1 } } + case FloatType => { + var i = 0 + while (i < a1.length) { + assert(doubleEquals(a1(i).asInstanceOf[Float], a2(i).asInstanceOf[Float]), + "Seed = " + seed) + i += 1 + } + } + + case t: DecimalType => + var i = 0 + while (i < a1.length) { + assert((a1(i) == null) == (a2(i) == null), "Seed = " + seed) + if (a1(i) != null) { + val d1 = a1(i).asInstanceOf[Decimal].toBigDecimal + val d2 = a2(i).asInstanceOf[java.math.BigDecimal] + assert(d1.compare(d2) == 0, "Seed = " + seed) + } + i += 1 + } + case _ => assert(a1 === a2, "Seed = " + seed) } case StructType(childFields) => @@ -644,10 +675,13 @@ class ColumnarBatchSuite extends SparkFunSuite { * results. */ def testRandomRows(flatSchema: Boolean, numFields: Int) { - // TODO: add remaining types. Figure out why StringType doesn't work on jenkins. - val types = Array(ByteType, IntegerType, LongType, DoubleType) + // TODO: Figure out why StringType doesn't work on jenkins. + val types = Array( + BooleanType, ByteType, FloatType, DoubleType, + IntegerType, LongType, ShortType, DecimalType.IntDecimal, new DecimalType(30, 10), + CalendarIntervalType) val seed = System.nanoTime() - val NUM_ROWS = 500 + val NUM_ROWS = 200 val NUM_ITERS = 1000 val random = new Random(seed) var i = 0 @@ -682,7 +716,7 @@ class ColumnarBatchSuite extends SparkFunSuite { } test("Random flat schema") { - testRandomRows(true, 10) + testRandomRows(true, 15) } test("Random nested schema") { diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index b29bf6a464b3..18761bfd222a 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -27,10 +27,14 @@ public final class Platform { public static final int BYTE_ARRAY_OFFSET; + public static final int SHORT_ARRAY_OFFSET; + public static final int INT_ARRAY_OFFSET; public static final int LONG_ARRAY_OFFSET; + public static final int FLOAT_ARRAY_OFFSET; + public static final int DOUBLE_ARRAY_OFFSET; public static int getInt(Object object, long offset) { @@ -168,13 +172,17 @@ public static void throwException(Throwable t) { if (_UNSAFE != null) { BYTE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(byte[].class); + SHORT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(short[].class); INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class); LONG_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(long[].class); + FLOAT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(float[].class); DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class); } else { BYTE_ARRAY_OFFSET = 0; + SHORT_ARRAY_OFFSET = 0; INT_ARRAY_OFFSET = 0; LONG_ARRAY_OFFSET = 0; + FLOAT_ARRAY_OFFSET = 0; DOUBLE_ARRAY_OFFSET = 0; } } From a2973fed30fbe9a0b12e1c1225359fdf55d322b4 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Mon, 1 Feb 2016 13:57:48 -0800 Subject: [PATCH 087/131] =?UTF-8?q?Fix=20for=20[SPARK-12854][SQL]=20Implem?= =?UTF-8?q?ent=20complex=20types=20support=20in=20Columna=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; 
charset=UTF-8 Content-Transfer-Encoding: 8bit …rBatch Fixes build for Scala 2.11. Author: Jacek Laskowski Closes #10946 from jaceklaskowski/SPARK-12854-fix. --- .../spark/sql/execution/vectorized/OffHeapColumnVector.java | 2 +- .../spark/sql/execution/vectorized/OnHeapColumnVector.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 22c5e5fc81a4..7a224d19d15b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -367,7 +367,7 @@ public final int putByteArray(int rowId, byte[] value, int offset, int length) { } @Override - public final void loadBytes(Array array) { + public final void loadBytes(ColumnVector.Array array) { if (array.tmpByteArray.length < array.length) array.tmpByteArray = new byte[array.length]; Platform.copyMemory( null, data + array.offset, array.tmpByteArray, Platform.BYTE_ARRAY_OFFSET, array.length); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 32356334c031..c42bbd642eca 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -331,7 +331,7 @@ public final void putArray(int rowId, int offset, int length) { } @Override - public final void loadBytes(Array array) { + public final void loadBytes(ColumnVector.Array array) { array.byteArray = byteData; array.byteArrayOffset = array.offset; } From be7a2fc0716b7d25327b6f8f683390fc62532e3b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 1 Feb 2016 14:11:52 -0800 Subject: [PATCH 088/131] [SPARK-13078][SQL] API and test cases for internal catalog This pull request creates an internal catalog API. The creation of this API is the first step towards consolidating SQLContext and HiveContext. I envision we will have two different implementations in Spark 2.0: (1) a simple in-memory implementation, and (2) an implementation based on the current HiveClient (ClientWrapper). I took a look at what Hive's internal metastore interface/implementation, and then created this API based on it. I believe this is the minimal set needed in order to achieve all the needed functionality. Author: Reynold Xin Closes #10982 from rxin/SPARK-13078. 
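As a rough usage sketch (an editorial illustration built only from the interface and case classes added below, not code that is part of this patch; the database name, table name, and field values are arbitrary), driving the new in-memory catalog looks roughly like this:

    import org.apache.spark.sql.catalyst.catalog._

    // Minimal sketch, not part of the patch: create a database, register a table in it,
    // then list the tables. The Database and Table arguments mirror the case classes
    // defined in interface.scala; the concrete values here are placeholders.
    val catalog: Catalog = new InMemoryCatalog
    catalog.createDatabase(
      Database("db1", "example database", locationUri = "/tmp/db1", properties = Map.empty),
      ifNotExists = false)
    catalog.createTable(
      "db1",
      Table("tbl1", "", schema = Seq.empty, partitionColumns = Seq.empty, sortColumns = Seq.empty,
        storage = null, numBuckets = 0, properties = Map.empty, tableType = "MANAGED_TABLE",
        createTime = 0L, lastAccessTime = 0L, viewOriginalText = None, viewText = None),
      ignoreIfExists = false)
    assert(catalog.listTables("db1") == Seq("tbl1"))

The CatalogTestCases suite below exercises the same calls against any Catalog implementation, so a future Hive-backed catalog only needs to supply newEmptyCatalog() to reuse the whole suite.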
--- .../catalyst/catalog/InMemoryCatalog.scala | 246 ++++++++++++++++ .../sql/catalyst/catalog/interface.scala | 178 ++++++++++++ .../catalyst/catalog/CatalogTestCases.scala | 263 ++++++++++++++++++ .../catalog/InMemoryCatalogSuite.scala | 23 ++ 4 files changed, 710 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalogSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala new file mode 100644 index 000000000000..9e6dfb7e9506 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import scala.collection.mutable + +import org.apache.spark.sql.AnalysisException + + +/** + * An in-memory (ephemeral) implementation of the system catalog. + * + * All public methods should be synchronized for thread-safety. 
+ */ +class InMemoryCatalog extends Catalog { + + private class TableDesc(var table: Table) { + val partitions = new mutable.HashMap[String, TablePartition] + } + + private class DatabaseDesc(var db: Database) { + val tables = new mutable.HashMap[String, TableDesc] + val functions = new mutable.HashMap[String, Function] + } + + private val catalog = new scala.collection.mutable.HashMap[String, DatabaseDesc] + + private def filterPattern(names: Seq[String], pattern: String): Seq[String] = { + val regex = pattern.replaceAll("\\*", ".*").r + names.filter { funcName => regex.pattern.matcher(funcName).matches() } + } + + private def existsFunction(db: String, funcName: String): Boolean = { + catalog(db).functions.contains(funcName) + } + + private def existsTable(db: String, table: String): Boolean = { + catalog(db).tables.contains(table) + } + + private def assertDbExists(db: String): Unit = { + if (!catalog.contains(db)) { + throw new AnalysisException(s"Database $db does not exist") + } + } + + private def assertFunctionExists(db: String, funcName: String): Unit = { + assertDbExists(db) + if (!existsFunction(db, funcName)) { + throw new AnalysisException(s"Function $funcName does not exists in $db database") + } + } + + private def assertTableExists(db: String, table: String): Unit = { + assertDbExists(db) + if (!existsTable(db, table)) { + throw new AnalysisException(s"Table $table does not exists in $db database") + } + } + + // -------------------------------------------------------------------------- + // Databases + // -------------------------------------------------------------------------- + + override def createDatabase(dbDefinition: Database, ifNotExists: Boolean): Unit = synchronized { + if (catalog.contains(dbDefinition.name)) { + if (!ifNotExists) { + throw new AnalysisException(s"Database ${dbDefinition.name} already exists.") + } + } else { + catalog.put(dbDefinition.name, new DatabaseDesc(dbDefinition)) + } + } + + override def dropDatabase( + db: String, + ignoreIfNotExists: Boolean, + cascade: Boolean): Unit = synchronized { + if (catalog.contains(db)) { + if (!cascade) { + // If cascade is false, make sure the database is empty. + if (catalog(db).tables.nonEmpty) { + throw new AnalysisException(s"Database $db is not empty. One or more tables exist.") + } + if (catalog(db).functions.nonEmpty) { + throw new AnalysisException(s"Database $db is not empty. One or more functions exist.") + } + } + // Remove the database. 
+ catalog.remove(db) + } else { + if (!ignoreIfNotExists) { + throw new AnalysisException(s"Database $db does not exist") + } + } + } + + override def alterDatabase(db: String, dbDefinition: Database): Unit = synchronized { + assertDbExists(db) + assert(db == dbDefinition.name) + catalog(db).db = dbDefinition + } + + override def getDatabase(db: String): Database = synchronized { + assertDbExists(db) + catalog(db).db + } + + override def listDatabases(): Seq[String] = synchronized { + catalog.keySet.toSeq + } + + override def listDatabases(pattern: String): Seq[String] = synchronized { + filterPattern(listDatabases(), pattern) + } + + // -------------------------------------------------------------------------- + // Tables + // -------------------------------------------------------------------------- + + override def createTable(db: String, tableDefinition: Table, ifNotExists: Boolean) + : Unit = synchronized { + assertDbExists(db) + if (existsTable(db, tableDefinition.name)) { + if (!ifNotExists) { + throw new AnalysisException(s"Table ${tableDefinition.name} already exists in $db database") + } + } else { + catalog(db).tables.put(tableDefinition.name, new TableDesc(tableDefinition)) + } + } + + override def dropTable(db: String, table: String, ignoreIfNotExists: Boolean) + : Unit = synchronized { + assertDbExists(db) + if (existsTable(db, table)) { + catalog(db).tables.remove(table) + } else { + if (!ignoreIfNotExists) { + throw new AnalysisException(s"Table $table does not exist in $db database") + } + } + } + + override def renameTable(db: String, oldName: String, newName: String): Unit = synchronized { + assertTableExists(db, oldName) + val oldDesc = catalog(db).tables(oldName) + oldDesc.table = oldDesc.table.copy(name = newName) + catalog(db).tables.put(newName, oldDesc) + catalog(db).tables.remove(oldName) + } + + override def alterTable(db: String, table: String, tableDefinition: Table): Unit = synchronized { + assertTableExists(db, table) + assert(table == tableDefinition.name) + catalog(db).tables(table).table = tableDefinition + } + + override def getTable(db: String, table: String): Table = synchronized { + assertTableExists(db, table) + catalog(db).tables(table).table + } + + override def listTables(db: String): Seq[String] = synchronized { + assertDbExists(db) + catalog(db).tables.keySet.toSeq + } + + override def listTables(db: String, pattern: String): Seq[String] = synchronized { + assertDbExists(db) + filterPattern(listTables(db), pattern) + } + + // -------------------------------------------------------------------------- + // Partitions + // -------------------------------------------------------------------------- + + override def alterPartition(db: String, table: String, part: TablePartition) + : Unit = synchronized { + throw new UnsupportedOperationException + } + + override def alterPartitions(db: String, table: String, parts: Seq[TablePartition]) + : Unit = synchronized { + throw new UnsupportedOperationException + } + + // -------------------------------------------------------------------------- + // Functions + // -------------------------------------------------------------------------- + + override def createFunction( + db: String, func: Function, ifNotExists: Boolean): Unit = synchronized { + assertDbExists(db) + + if (existsFunction(db, func.name)) { + if (!ifNotExists) { + throw new AnalysisException(s"Function $func already exists in $db database") + } + } else { + catalog(db).functions.put(func.name, func) + } + } + + override def dropFunction(db: 
String, funcName: String): Unit = synchronized { + assertFunctionExists(db, funcName) + catalog(db).functions.remove(funcName) + } + + override def alterFunction(db: String, funcName: String, funcDefinition: Function) + : Unit = synchronized { + assertFunctionExists(db, funcName) + if (funcName != funcDefinition.name) { + // Also a rename; remove the old one and add the new one back + catalog(db).functions.remove(funcName) + } + catalog(db).functions.put(funcName, funcDefinition) + } + + override def getFunction(db: String, funcName: String): Function = synchronized { + assertFunctionExists(db, funcName) + catalog(db).functions(funcName) + } + + override def listFunctions(db: String, pattern: String): Seq[String] = synchronized { + assertDbExists(db) + val regex = pattern.replaceAll("\\*", ".*").r + filterPattern(catalog(db).functions.keysIterator.toSeq, pattern) + } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala new file mode 100644 index 000000000000..a6caf91f3304 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import org.apache.spark.sql.AnalysisException + + +/** + * Interface for the system catalog (of columns, partitions, tables, and databases). + * + * This is only used for non-temporary items, and implementations must be thread-safe as they + * can be accessed in multiple threads. + * + * Implementations should throw [[AnalysisException]] when table or database don't exist. 
+ */ +abstract class Catalog { + + // -------------------------------------------------------------------------- + // Databases + // -------------------------------------------------------------------------- + + def createDatabase(dbDefinition: Database, ifNotExists: Boolean): Unit + + def dropDatabase( + db: String, + ignoreIfNotExists: Boolean, + cascade: Boolean): Unit + + def alterDatabase(db: String, dbDefinition: Database): Unit + + def getDatabase(db: String): Database + + def listDatabases(): Seq[String] + + def listDatabases(pattern: String): Seq[String] + + // -------------------------------------------------------------------------- + // Tables + // -------------------------------------------------------------------------- + + def createTable(db: String, tableDefinition: Table, ignoreIfExists: Boolean): Unit + + def dropTable(db: String, table: String, ignoreIfNotExists: Boolean): Unit + + def renameTable(db: String, oldName: String, newName: String): Unit + + def alterTable(db: String, table: String, tableDefinition: Table): Unit + + def getTable(db: String, table: String): Table + + def listTables(db: String): Seq[String] + + def listTables(db: String, pattern: String): Seq[String] + + // -------------------------------------------------------------------------- + // Partitions + // -------------------------------------------------------------------------- + + // TODO: need more functions for partitioning. + + def alterPartition(db: String, table: String, part: TablePartition): Unit + + def alterPartitions(db: String, table: String, parts: Seq[TablePartition]): Unit + + // -------------------------------------------------------------------------- + // Functions + // -------------------------------------------------------------------------- + + def createFunction(db: String, funcDefinition: Function, ignoreIfExists: Boolean): Unit + + def dropFunction(db: String, funcName: String): Unit + + def alterFunction(db: String, funcName: String, funcDefinition: Function): Unit + + def getFunction(db: String, funcName: String): Function + + def listFunctions(db: String, pattern: String): Seq[String] + +} + + +/** + * A function defined in the catalog. + * + * @param name name of the function + * @param className fully qualified class name, e.g. "org.apache.spark.util.MyFunc" + */ +case class Function( + name: String, + className: String +) + + +/** + * Storage format, used to describe how a partition or a table is stored. + */ +case class StorageFormat( + locationUri: String, + inputFormat: String, + outputFormat: String, + serde: String, + serdeProperties: Map[String, String] +) + + +/** + * A column in a table. + */ +case class Column( + name: String, + dataType: String, + nullable: Boolean, + comment: String +) + + +/** + * A partition (Hive style) defined in the catalog. + * + * @param values values for the partition columns + * @param storage storage format of the partition + */ +case class TablePartition( + values: Seq[String], + storage: StorageFormat +) + + +/** + * A table defined in the catalog. + * + * Note that Hive's metastore also tracks skewed columns. We should consider adding that in the + * future once we have a better understanding of how we want to handle skewed columns. 
+ */ +case class Table( + name: String, + description: String, + schema: Seq[Column], + partitionColumns: Seq[Column], + sortColumns: Seq[Column], + storage: StorageFormat, + numBuckets: Int, + properties: Map[String, String], + tableType: String, + createTime: Long, + lastAccessTime: Long, + viewOriginalText: Option[String], + viewText: Option[String]) { + + require(tableType == "EXTERNAL_TABLE" || tableType == "INDEX_TABLE" || + tableType == "MANAGED_TABLE" || tableType == "VIRTUAL_VIEW") +} + + +/** + * A database defined in the catalog. + */ +case class Database( + name: String, + description: String, + locationUri: String, + properties: Map[String, String] +) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala new file mode 100644 index 000000000000..ab9d5ac8a20e --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException + + +/** + * A reasonable complete test suite (i.e. behaviors) for a [[Catalog]]. + * + * Implementations of the [[Catalog]] interface can create test suites by extending this. 
+ */ +abstract class CatalogTestCases extends SparkFunSuite { + + protected def newEmptyCatalog(): Catalog + + /** + * Creates a basic catalog, with the following structure: + * + * db1 + * db2 + * - tbl1 + * - tbl2 + * - func1 + */ + private def newBasicCatalog(): Catalog = { + val catalog = newEmptyCatalog() + catalog.createDatabase(newDb("db1"), ifNotExists = false) + catalog.createDatabase(newDb("db2"), ifNotExists = false) + + catalog.createTable("db2", newTable("tbl1"), ignoreIfExists = false) + catalog.createTable("db2", newTable("tbl2"), ignoreIfExists = false) + catalog.createFunction("db2", newFunc("func1"), ignoreIfExists = false) + catalog + } + + private def newFunc(): Function = Function("funcname", "org.apache.spark.MyFunc") + + private def newDb(name: String = "default"): Database = + Database(name, name + " description", "uri", Map.empty) + + private def newTable(name: String): Table = + Table(name, "", Seq.empty, Seq.empty, Seq.empty, null, 0, Map.empty, "EXTERNAL_TABLE", 0, 0, + None, None) + + private def newFunc(name: String): Function = Function(name, "class.name") + + // -------------------------------------------------------------------------- + // Databases + // -------------------------------------------------------------------------- + + test("basic create, drop and list databases") { + val catalog = newEmptyCatalog() + catalog.createDatabase(newDb(), ifNotExists = false) + assert(catalog.listDatabases().toSet == Set("default")) + + catalog.createDatabase(newDb("default2"), ifNotExists = false) + assert(catalog.listDatabases().toSet == Set("default", "default2")) + } + + test("get database when a database exists") { + val db1 = newBasicCatalog().getDatabase("db1") + assert(db1.name == "db1") + assert(db1.description.contains("db1")) + } + + test("get database should throw exception when the database does not exist") { + intercept[AnalysisException] { newBasicCatalog().getDatabase("db_that_does_not_exist") } + } + + test("list databases without pattern") { + val catalog = newBasicCatalog() + assert(catalog.listDatabases().toSet == Set("db1", "db2")) + } + + test("list databases with pattern") { + val catalog = newBasicCatalog() + assert(catalog.listDatabases("db").toSet == Set.empty) + assert(catalog.listDatabases("db*").toSet == Set("db1", "db2")) + assert(catalog.listDatabases("*1").toSet == Set("db1")) + assert(catalog.listDatabases("db2").toSet == Set("db2")) + } + + test("drop database") { + val catalog = newBasicCatalog() + catalog.dropDatabase("db1", ignoreIfNotExists = false, cascade = false) + assert(catalog.listDatabases().toSet == Set("db2")) + } + + test("drop database when the database is not empty") { + // Throw exception if there are functions left + val catalog1 = newBasicCatalog() + catalog1.dropTable("db2", "tbl1", ignoreIfNotExists = false) + catalog1.dropTable("db2", "tbl2", ignoreIfNotExists = false) + intercept[AnalysisException] { + catalog1.dropDatabase("db2", ignoreIfNotExists = false, cascade = false) + } + + // Throw exception if there are tables left + val catalog2 = newBasicCatalog() + catalog2.dropFunction("db2", "func1") + intercept[AnalysisException] { + catalog2.dropDatabase("db2", ignoreIfNotExists = false, cascade = false) + } + + // When cascade is true, it should drop them + val catalog3 = newBasicCatalog() + catalog3.dropDatabase("db2", ignoreIfNotExists = false, cascade = true) + assert(catalog3.listDatabases().toSet == Set("db1")) + } + + test("drop database when the database does not exist") { + val catalog = 
newBasicCatalog() + + intercept[AnalysisException] { + catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = false, cascade = false) + } + + catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = true, cascade = false) + } + + test("alter database") { + val catalog = newBasicCatalog() + catalog.alterDatabase("db1", Database("db1", "new description", "lll", Map.empty)) + assert(catalog.getDatabase("db1").description == "new description") + } + + test("alter database should throw exception when the database does not exist") { + intercept[AnalysisException] { + newBasicCatalog().alterDatabase("no_db", Database("no_db", "ddd", "lll", Map.empty)) + } + } + + // -------------------------------------------------------------------------- + // Tables + // -------------------------------------------------------------------------- + + test("drop table") { + val catalog = newBasicCatalog() + assert(catalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + catalog.dropTable("db2", "tbl1", ignoreIfNotExists = false) + assert(catalog.listTables("db2").toSet == Set("tbl2")) + } + + test("drop table when database / table does not exist") { + val catalog = newBasicCatalog() + + // Should always throw exception when the database does not exist + intercept[AnalysisException] { + catalog.dropTable("unknown_db", "unknown_table", ignoreIfNotExists = false) + } + + intercept[AnalysisException] { + catalog.dropTable("unknown_db", "unknown_table", ignoreIfNotExists = true) + } + + // Should throw exception when the table does not exist, if ignoreIfNotExists is false + intercept[AnalysisException] { + catalog.dropTable("db2", "unknown_table", ignoreIfNotExists = false) + } + + catalog.dropTable("db2", "unknown_table", ignoreIfNotExists = true) + } + + test("rename table") { + val catalog = newBasicCatalog() + + assert(catalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + catalog.renameTable("db2", "tbl1", "tblone") + assert(catalog.listTables("db2").toSet == Set("tblone", "tbl2")) + } + + test("rename table when database / table does not exist") { + val catalog = newBasicCatalog() + + intercept[AnalysisException] { // Throw exception when the database does not exist + catalog.renameTable("unknown_db", "unknown_table", "unknown_table") + } + + intercept[AnalysisException] { // Throw exception when the table does not exist + catalog.renameTable("db2", "unknown_table", "unknown_table") + } + } + + test("alter table") { + val catalog = newBasicCatalog() + catalog.alterTable("db2", "tbl1", newTable("tbl1").copy(createTime = 10)) + assert(catalog.getTable("db2", "tbl1").createTime == 10) + } + + test("alter table when database / table does not exist") { + val catalog = newBasicCatalog() + + intercept[AnalysisException] { // Throw exception when the database does not exist + catalog.alterTable("unknown_db", "unknown_table", newTable("unknown_table")) + } + + intercept[AnalysisException] { // Throw exception when the table does not exist + catalog.alterTable("db2", "unknown_table", newTable("unknown_table")) + } + } + + test("get table") { + assert(newBasicCatalog().getTable("db2", "tbl1").name == "tbl1") + } + + test("get table when database / table does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.getTable("unknown_db", "unknown_table") + } + + intercept[AnalysisException] { + catalog.getTable("db2", "unknown_table") + } + } + + test("list tables without pattern") { + val catalog = newBasicCatalog() + assert(catalog.listTables("db1").toSet == 
Set.empty) + assert(catalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + } + + test("list tables with pattern") { + val catalog = newBasicCatalog() + + // Test when database does not exist + intercept[AnalysisException] { catalog.listTables("unknown_db") } + + assert(catalog.listTables("db1", "*").toSet == Set.empty) + assert(catalog.listTables("db2", "*").toSet == Set("tbl1", "tbl2")) + assert(catalog.listTables("db2", "tbl*").toSet == Set("tbl1", "tbl2")) + assert(catalog.listTables("db2", "*1").toSet == Set("tbl1")) + } + + // -------------------------------------------------------------------------- + // Partitions + // -------------------------------------------------------------------------- + + // TODO: Add tests cases for partitions + + // -------------------------------------------------------------------------- + // Functions + // -------------------------------------------------------------------------- + + // TODO: Add tests cases for functions +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalogSuite.scala new file mode 100644 index 000000000000..871f0a0f46a2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalogSuite.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +/** Test suite for the [[InMemoryCatalog]]. */ +class InMemoryCatalogSuite extends CatalogTestCases { + override protected def newEmptyCatalog(): Catalog = new InMemoryCatalog +} From 715a19d56fc934d4aec5025739ff650daf4580b7 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 1 Feb 2016 16:23:17 -0800 Subject: [PATCH 089/131] [SPARK-12637][CORE] Print stage info of finished stages properly Improve printing of StageInfo in onStageCompleted See also https://github.com/apache/spark/pull/10585 Author: Sean Owen Closes #10922 from srowen/SPARK-12637. 
--- .../org/apache/spark/scheduler/SparkListener.scala | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index ed3adbd81c28..7b09c2eded0b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -270,7 +270,7 @@ class StatsReportListener extends SparkListener with Logging { override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { implicit val sc = stageCompleted - this.logInfo("Finished stage: " + stageCompleted.stageInfo) + this.logInfo(s"Finished stage: ${getStatusDetail(stageCompleted.stageInfo)}") showMillisDistribution("task runtime:", (info, _) => Some(info.duration), taskInfoMetrics) // Shuffle write @@ -297,6 +297,17 @@ class StatsReportListener extends SparkListener with Logging { taskInfoMetrics.clear() } + private def getStatusDetail(info: StageInfo): String = { + val failureReason = info.failureReason.map("(" + _ + ")").getOrElse("") + val timeTaken = info.submissionTime.map( + x => info.completionTime.getOrElse(System.currentTimeMillis()) - x + ).getOrElse("-") + + s"Stage(${info.stageId}, ${info.attemptId}); Name: '${info.name}'; " + + s"Status: ${info.getStatusString}$failureReason; numTasks: ${info.numTasks}; " + + s"Took: $timeTaken msec" + } + } private[spark] object StatsReportListener extends Logging { From 0df3cfb8ab4d584c95db6c340694e199d7b59e9e Mon Sep 17 00:00:00 2001 From: felixcheung Date: Mon, 1 Feb 2016 16:55:21 -0800 Subject: [PATCH 090/131] [SPARK-12790][CORE] Remove HistoryServer old multiple files format Removed isLegacyLogDirectory code path and updated tests andrewor14 Author: felixcheung Closes #10860 from felixcheung/historyserverformat. 
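With the legacy multi-file directory layout gone, the provider only lists the single-file event logs written by EventLoggingListener. As an illustration only (not part of this patch; the paths and app name are placeholders), an application produces such logs, and the history server can pick them up, with configuration along these lines:

    import org.apache.spark.{SparkConf, SparkContext}

    // Illustration only: write one event log file per application attempt into a directory
    // that the history server is also pointed at via spark.history.fs.logDirectory.
    val conf = new SparkConf()
      .setAppName("event-log-demo")
      .setMaster("local[2]")
      .set("spark.eventLog.enabled", "true")
      .set("spark.eventLog.dir", "hdfs:///spark-events")  // placeholder path
    val sc = new SparkContext(conf)
    sc.parallelize(1 to 100).count()
    sc.stop()  // closing the context finalizes the log so it is listed as a completed application

The FsHistoryProvider changes below then rely on the log file's own modification time and the application ID recorded in the events, instead of the old per-directory SPARK_VERSION_* and APPLICATION_COMPLETE marker files.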
--- .rat-excludes | 12 +- .../deploy/history/FsHistoryProvider.scala | 124 ++---------------- .../scheduler/EventLoggingListener.scala | 2 - .../EVENT_LOG_1 => local-1422981759269} | 0 .../local-1422981759269/APPLICATION_COMPLETE | 0 .../local-1422981759269/SPARK_VERSION_1.2.0 | 0 .../EVENT_LOG_1 => local-1422981780767} | 0 .../local-1422981780767/APPLICATION_COMPLETE | 0 .../local-1422981780767/SPARK_VERSION_1.2.0 | 0 .../EVENT_LOG_1 => local-1425081759269} | 0 .../local-1425081759269/APPLICATION_COMPLETE | 0 .../local-1425081759269/SPARK_VERSION_1.2.0 | 0 .../EVENT_LOG_1 => local-1426533911241} | 0 .../local-1426533911241/APPLICATION_COMPLETE | 0 .../local-1426533911241/SPARK_VERSION_1.2.0 | 0 .../EVENT_LOG_1 => local-1426633911242} | 0 .../local-1426633911242/APPLICATION_COMPLETE | 0 .../local-1426633911242/SPARK_VERSION_1.2.0 | 0 .../history/FsHistoryProviderSuite.scala | 95 +------------- .../deploy/history/HistoryServerSuite.scala | 25 +--- 20 files changed, 23 insertions(+), 235 deletions(-) rename core/src/test/resources/spark-events/{local-1422981759269/EVENT_LOG_1 => local-1422981759269} (100%) delete mode 100755 core/src/test/resources/spark-events/local-1422981759269/APPLICATION_COMPLETE delete mode 100755 core/src/test/resources/spark-events/local-1422981759269/SPARK_VERSION_1.2.0 rename core/src/test/resources/spark-events/{local-1422981780767/EVENT_LOG_1 => local-1422981780767} (100%) delete mode 100755 core/src/test/resources/spark-events/local-1422981780767/APPLICATION_COMPLETE delete mode 100755 core/src/test/resources/spark-events/local-1422981780767/SPARK_VERSION_1.2.0 rename core/src/test/resources/spark-events/{local-1425081759269/EVENT_LOG_1 => local-1425081759269} (100%) delete mode 100755 core/src/test/resources/spark-events/local-1425081759269/APPLICATION_COMPLETE delete mode 100755 core/src/test/resources/spark-events/local-1425081759269/SPARK_VERSION_1.2.0 rename core/src/test/resources/spark-events/{local-1426533911241/EVENT_LOG_1 => local-1426533911241} (100%) delete mode 100755 core/src/test/resources/spark-events/local-1426533911241/APPLICATION_COMPLETE delete mode 100755 core/src/test/resources/spark-events/local-1426533911241/SPARK_VERSION_1.2.0 rename core/src/test/resources/spark-events/{local-1426633911242/EVENT_LOG_1 => local-1426633911242} (100%) delete mode 100755 core/src/test/resources/spark-events/local-1426633911242/APPLICATION_COMPLETE delete mode 100755 core/src/test/resources/spark-events/local-1426633911242/SPARK_VERSION_1.2.0 diff --git a/.rat-excludes b/.rat-excludes index 874a6ee9f404..8b5061415ff4 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -73,12 +73,12 @@ logs .*dependency-reduced-pom.xml known_translations json_expectation -local-1422981759269/* -local-1422981780767/* -local-1425081759269/* -local-1426533911241/* -local-1426633911242/* -local-1430917381534/* +local-1422981759269 +local-1422981780767 +local-1425081759269 +local-1426533911241 +local-1426633911242 +local-1430917381534 local-1430917381535_1 local-1430917381535_2 DESCRIPTION diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 22e4155cc545..9648959dbacb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -248,9 +248,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val logInfos: Seq[FileStatus] = 
statusList .filter { entry => try { - getModificationTime(entry).map { time => - time >= lastScanTime - }.getOrElse(false) + !entry.isDirectory() && (entry.getModificationTime() >= lastScanTime) } catch { case e: AccessControlException => // Do not use "logInfo" since these messages can get pretty noisy if printed on @@ -261,9 +259,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } .flatMap { entry => Some(entry) } .sortWith { case (entry1, entry2) => - val mod1 = getModificationTime(entry1).getOrElse(-1L) - val mod2 = getModificationTime(entry2).getOrElse(-1L) - mod1 >= mod2 + entry1.getModificationTime() >= entry2.getModificationTime() } logInfos.grouped(20) @@ -341,19 +337,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) attempt.attemptId.isEmpty || attemptId.isEmpty || attempt.attemptId.get == attemptId.get }.foreach { attempt => val logPath = new Path(logDir, attempt.logPath) - // If this is a legacy directory, then add the directory to the zipStream and add - // each file to that directory. - if (isLegacyLogDirectory(fs.getFileStatus(logPath))) { - val files = fs.listStatus(logPath) - zipStream.putNextEntry(new ZipEntry(attempt.logPath + "/")) - zipStream.closeEntry() - files.foreach { file => - val path = file.getPath - zipFileToStream(path, attempt.logPath + Path.SEPARATOR + path.getName, zipStream) - } - } else { - zipFileToStream(new Path(logDir, attempt.logPath), attempt.logPath, zipStream) - } + zipFileToStream(new Path(logDir, attempt.logPath), attempt.logPath, zipStream) } } finally { zipStream.close() @@ -527,12 +511,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) bus: ReplayListenerBus): Option[FsApplicationAttemptInfo] = { val logPath = eventLog.getPath() logInfo(s"Replaying log path: $logPath") - val logInput = - if (isLegacyLogDirectory(eventLog)) { - openLegacyEventLog(logPath) - } else { - EventLoggingListener.openEventLog(logPath, fs) - } + val logInput = EventLoggingListener.openEventLog(logPath, fs) try { val appListener = new ApplicationEventListener val appCompleted = isApplicationCompleted(eventLog) @@ -540,9 +519,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) bus.replay(logInput, logPath.toString, !appCompleted) // Without an app ID, new logs will render incorrectly in the listing page, so do not list or - // try to show their UI. Some old versions of Spark generate logs without an app ID, so let - // logs generated by those versions go through. - if (appListener.appId.isDefined || !sparkVersionHasAppId(eventLog)) { + // try to show their UI. + if (appListener.appId.isDefined) { Some(new FsApplicationAttemptInfo( logPath.getName(), appListener.appName.getOrElse(NOT_STARTED), @@ -550,7 +528,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) appListener.appAttemptId, appListener.startTime.getOrElse(-1L), appListener.endTime.getOrElse(-1L), - getModificationTime(eventLog).get, + eventLog.getModificationTime(), appListener.sparkUser.getOrElse(NOT_STARTED), appCompleted)) } else { @@ -561,91 +539,11 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - /** - * Loads a legacy log directory. This assumes that the log directory contains a single event - * log file (along with other metadata files), which is the case for directories generated by - * the code in previous releases. - * - * @return input stream that holds one JSON record per line. 
- */ - private[history] def openLegacyEventLog(dir: Path): InputStream = { - val children = fs.listStatus(dir) - var eventLogPath: Path = null - var codecName: Option[String] = None - - children.foreach { child => - child.getPath().getName() match { - case name if name.startsWith(LOG_PREFIX) => - eventLogPath = child.getPath() - case codec if codec.startsWith(COMPRESSION_CODEC_PREFIX) => - codecName = Some(codec.substring(COMPRESSION_CODEC_PREFIX.length())) - case _ => - } - } - - if (eventLogPath == null) { - throw new IllegalArgumentException(s"$dir is not a Spark application log directory.") - } - - val codec = try { - codecName.map { c => CompressionCodec.createCodec(conf, c) } - } catch { - case e: Exception => - throw new IllegalArgumentException(s"Unknown compression codec $codecName.") - } - - val in = new BufferedInputStream(fs.open(eventLogPath)) - codec.map(_.compressedInputStream(in)).getOrElse(in) - } - - /** - * Return whether the specified event log path contains a old directory-based event log. - * Previously, the event log of an application comprises of multiple files in a directory. - * As of Spark 1.3, these files are consolidated into a single one that replaces the directory. - * See SPARK-2261 for more detail. - */ - private def isLegacyLogDirectory(entry: FileStatus): Boolean = entry.isDirectory - - /** - * Returns the modification time of the given event log. If the status points at an empty - * directory, `None` is returned, indicating that there isn't an event log at that location. - */ - private def getModificationTime(fsEntry: FileStatus): Option[Long] = { - if (isLegacyLogDirectory(fsEntry)) { - val statusList = fs.listStatus(fsEntry.getPath) - if (!statusList.isEmpty) Some(statusList.map(_.getModificationTime()).max) else None - } else { - Some(fsEntry.getModificationTime()) - } - } - /** * Return true when the application has completed. */ private def isApplicationCompleted(entry: FileStatus): Boolean = { - if (isLegacyLogDirectory(entry)) { - fs.exists(new Path(entry.getPath(), APPLICATION_COMPLETE)) - } else { - !entry.getPath().getName().endsWith(EventLoggingListener.IN_PROGRESS) - } - } - - /** - * Returns whether the version of Spark that generated logs records app IDs. App IDs were added - * in Spark 1.1. - */ - private def sparkVersionHasAppId(entry: FileStatus): Boolean = { - if (isLegacyLogDirectory(entry)) { - fs.listStatus(entry.getPath()) - .find { status => status.getPath().getName().startsWith(SPARK_VERSION_PREFIX) } - .map { status => - val version = status.getPath().getName().substring(SPARK_VERSION_PREFIX.length()) - version != "1.0" && version != "1.1" - } - .getOrElse(true) - } else { - true - } + !entry.getPath().getName().endsWith(EventLoggingListener.IN_PROGRESS) } /** @@ -670,12 +568,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) private[history] object FsHistoryProvider { val DEFAULT_LOG_DIR = "file:/tmp/spark-events" - - // Constants used to parse Spark 1.0.0 log directories. 
- val LOG_PREFIX = "EVENT_LOG_" - val SPARK_VERSION_PREFIX = EventLoggingListener.SPARK_VERSION_KEY + "_" - val COMPRESSION_CODEC_PREFIX = EventLoggingListener.COMPRESSION_CODEC_KEY + "_" - val APPLICATION_COMPLETE = "APPLICATION_COMPLETE" } private class FsApplicationAttemptInfo( diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 36f2b74f948f..01fee46e73a8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -232,8 +232,6 @@ private[spark] object EventLoggingListener extends Logging { // Suffix applied to the names of files still being written by applications. val IN_PROGRESS = ".inprogress" val DEFAULT_LOG_DIR = "/tmp/spark-events" - val SPARK_VERSION_KEY = "SPARK_VERSION" - val COMPRESSION_CODEC_KEY = "COMPRESSION_CODEC" private val LOG_FILE_PERMISSIONS = new FsPermission(Integer.parseInt("770", 8).toShort) diff --git a/core/src/test/resources/spark-events/local-1422981759269/EVENT_LOG_1 b/core/src/test/resources/spark-events/local-1422981759269 similarity index 100% rename from core/src/test/resources/spark-events/local-1422981759269/EVENT_LOG_1 rename to core/src/test/resources/spark-events/local-1422981759269 diff --git a/core/src/test/resources/spark-events/local-1422981759269/APPLICATION_COMPLETE b/core/src/test/resources/spark-events/local-1422981759269/APPLICATION_COMPLETE deleted file mode 100755 index e69de29bb2d1..000000000000 diff --git a/core/src/test/resources/spark-events/local-1422981759269/SPARK_VERSION_1.2.0 b/core/src/test/resources/spark-events/local-1422981759269/SPARK_VERSION_1.2.0 deleted file mode 100755 index e69de29bb2d1..000000000000 diff --git a/core/src/test/resources/spark-events/local-1422981780767/EVENT_LOG_1 b/core/src/test/resources/spark-events/local-1422981780767 similarity index 100% rename from core/src/test/resources/spark-events/local-1422981780767/EVENT_LOG_1 rename to core/src/test/resources/spark-events/local-1422981780767 diff --git a/core/src/test/resources/spark-events/local-1422981780767/APPLICATION_COMPLETE b/core/src/test/resources/spark-events/local-1422981780767/APPLICATION_COMPLETE deleted file mode 100755 index e69de29bb2d1..000000000000 diff --git a/core/src/test/resources/spark-events/local-1422981780767/SPARK_VERSION_1.2.0 b/core/src/test/resources/spark-events/local-1422981780767/SPARK_VERSION_1.2.0 deleted file mode 100755 index e69de29bb2d1..000000000000 diff --git a/core/src/test/resources/spark-events/local-1425081759269/EVENT_LOG_1 b/core/src/test/resources/spark-events/local-1425081759269 similarity index 100% rename from core/src/test/resources/spark-events/local-1425081759269/EVENT_LOG_1 rename to core/src/test/resources/spark-events/local-1425081759269 diff --git a/core/src/test/resources/spark-events/local-1425081759269/APPLICATION_COMPLETE b/core/src/test/resources/spark-events/local-1425081759269/APPLICATION_COMPLETE deleted file mode 100755 index e69de29bb2d1..000000000000 diff --git a/core/src/test/resources/spark-events/local-1425081759269/SPARK_VERSION_1.2.0 b/core/src/test/resources/spark-events/local-1425081759269/SPARK_VERSION_1.2.0 deleted file mode 100755 index e69de29bb2d1..000000000000 diff --git a/core/src/test/resources/spark-events/local-1426533911241/EVENT_LOG_1 b/core/src/test/resources/spark-events/local-1426533911241 similarity index 100% rename from 
core/src/test/resources/spark-events/local-1426533911241/EVENT_LOG_1 rename to core/src/test/resources/spark-events/local-1426533911241 diff --git a/core/src/test/resources/spark-events/local-1426533911241/APPLICATION_COMPLETE b/core/src/test/resources/spark-events/local-1426533911241/APPLICATION_COMPLETE deleted file mode 100755 index e69de29bb2d1..000000000000 diff --git a/core/src/test/resources/spark-events/local-1426533911241/SPARK_VERSION_1.2.0 b/core/src/test/resources/spark-events/local-1426533911241/SPARK_VERSION_1.2.0 deleted file mode 100755 index e69de29bb2d1..000000000000 diff --git a/core/src/test/resources/spark-events/local-1426633911242/EVENT_LOG_1 b/core/src/test/resources/spark-events/local-1426633911242 similarity index 100% rename from core/src/test/resources/spark-events/local-1426633911242/EVENT_LOG_1 rename to core/src/test/resources/spark-events/local-1426633911242 diff --git a/core/src/test/resources/spark-events/local-1426633911242/APPLICATION_COMPLETE b/core/src/test/resources/spark-events/local-1426633911242/APPLICATION_COMPLETE deleted file mode 100755 index e69de29bb2d1..000000000000 diff --git a/core/src/test/resources/spark-events/local-1426633911242/SPARK_VERSION_1.2.0 b/core/src/test/resources/spark-events/local-1426633911242/SPARK_VERSION_1.2.0 deleted file mode 100755 index e69de29bb2d1..000000000000 diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 6cbf911395a8..3baa2e2ddad3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -69,7 +69,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc new File(logPath) } - test("Parse new and old application logs") { + test("Parse application logs") { val provider = new FsHistoryProvider(createTestConf()) // Write a new-style application log. @@ -95,26 +95,11 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc None) ) - // Write an old-style application log. - val oldAppComplete = writeOldLog("old1", "1.0", None, true, - SparkListenerApplicationStart("old1", Some("old-app-complete"), 2L, "test", None), - SparkListenerApplicationEnd(3L) - ) - - // Check for logs so that we force the older unfinished app to be loaded, to make - // sure unfinished apps are also sorted correctly. - provider.checkForLogs() - - // Write an unfinished app, old-style. - val oldAppIncomplete = writeOldLog("old2", "1.0", None, false, - SparkListenerApplicationStart("old2", None, 2L, "test", None) - ) - - // Force a reload of data from the log directory, and check that both logs are loaded. + // Force a reload of data from the log directory, and check that logs are loaded. // Take the opportunity to check that the offset checks work as expected. 
updateAndCheck(provider) { list => - list.size should be (5) - list.count(_.attempts.head.completed) should be (3) + list.size should be (3) + list.count(_.attempts.head.completed) should be (2) def makeAppInfo( id: String, @@ -132,11 +117,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc newAppComplete.lastModified(), "test", true)) list(1) should be (makeAppInfo("new-complete-lzf", newAppCompressedComplete.getName(), 1L, 4L, newAppCompressedComplete.lastModified(), "test", true)) - list(2) should be (makeAppInfo("old-app-complete", oldAppComplete.getName(), 2L, 3L, - oldAppComplete.lastModified(), "test", true)) - list(3) should be (makeAppInfo(oldAppIncomplete.getName(), oldAppIncomplete.getName(), 2L, - -1L, oldAppIncomplete.lastModified(), "test", false)) - list(4) should be (makeAppInfo("new-incomplete", newAppIncomplete.getName(), 1L, -1L, + list(2) should be (makeAppInfo("new-incomplete", newAppIncomplete.getName(), 1L, -1L, newAppIncomplete.lastModified(), "test", false)) // Make sure the UI can be rendered. @@ -148,38 +129,6 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } - test("Parse legacy logs with compression codec set") { - val provider = new FsHistoryProvider(createTestConf()) - val testCodecs = List((classOf[LZFCompressionCodec].getName(), true), - (classOf[SnappyCompressionCodec].getName(), true), - ("invalid.codec", false)) - - testCodecs.foreach { case (codecName, valid) => - val codec = if (valid) CompressionCodec.createCodec(new SparkConf(), codecName) else null - val logDir = new File(testDir, codecName) - logDir.mkdir() - createEmptyFile(new File(logDir, SPARK_VERSION_PREFIX + "1.0")) - writeFile(new File(logDir, LOG_PREFIX + "1"), false, Option(codec), - SparkListenerApplicationStart("app2", None, 2L, "test", None), - SparkListenerApplicationEnd(3L) - ) - createEmptyFile(new File(logDir, COMPRESSION_CODEC_PREFIX + codecName)) - - val logPath = new Path(logDir.getAbsolutePath()) - try { - val logInput = provider.openLegacyEventLog(logPath) - try { - Source.fromInputStream(logInput).getLines().toSeq.size should be (2) - } finally { - logInput.close() - } - } catch { - case e: IllegalArgumentException => - valid should be (false) - } - } - } - test("SPARK-3697: ignore directories that cannot be read.") { val logFile1 = newLogFile("new1", None, inProgress = false) writeFile(logFile1, true, None, @@ -395,21 +344,8 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc SparkListenerLogStart("1.4") ) - // Write a 1.2 log file with no start event (= no app id), it should be ignored. - writeOldLog("v12Log", "1.2", None, false) - - // Write 1.0 and 1.1 logs, which don't have app ids. 
- writeOldLog("v11Log", "1.1", None, true, - SparkListenerApplicationStart("v11Log", None, 2L, "test", None), - SparkListenerApplicationEnd(3L)) - writeOldLog("v10Log", "1.0", None, true, - SparkListenerApplicationStart("v10Log", None, 2L, "test", None), - SparkListenerApplicationEnd(4L)) - updateAndCheck(provider) { list => - list.size should be (2) - list(0).id should be ("v10Log") - list(1).id should be ("v11Log") + list.size should be (0) } } @@ -499,25 +435,6 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc new SparkConf().set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) } - private def writeOldLog( - fname: String, - sparkVersion: String, - codec: Option[CompressionCodec], - completed: Boolean, - events: SparkListenerEvent*): File = { - val log = new File(testDir, fname) - log.mkdir() - - val oldEventLog = new File(log, LOG_PREFIX + "1") - createEmptyFile(new File(log, SPARK_VERSION_PREFIX + sparkVersion)) - writeFile(new File(log, LOG_PREFIX + "1"), false, codec, events: _*) - if (completed) { - createEmptyFile(new File(log, APPLICATION_COMPLETE)) - } - - log - } - private class SafeModeTestProvider(conf: SparkConf, clock: Clock) extends FsHistoryProvider(conf, clock) { diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index be55b2e0fe1b..40d0076eecfc 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -176,18 +176,8 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers (1 to 2).foreach { attemptId => doDownloadTest("local-1430917381535", Some(attemptId)) } } - test("download legacy logs - all attempts") { - doDownloadTest("local-1426533911241", None, legacy = true) - } - - test("download legacy logs - single attempts") { - (1 to 2). foreach { - attemptId => doDownloadTest("local-1426533911241", Some(attemptId), legacy = true) - } - } - // Test that the files are downloaded correctly, and validate them. - def doDownloadTest(appId: String, attemptId: Option[Int], legacy: Boolean = false): Unit = { + def doDownloadTest(appId: String, attemptId: Option[Int]): Unit = { val url = attemptId match { case Some(id) => @@ -205,22 +195,13 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers var entry = zipStream.getNextEntry entry should not be null val totalFiles = { - if (legacy) { - attemptId.map { x => 3 }.getOrElse(6) - } else { - attemptId.map { x => 1 }.getOrElse(2) - } + attemptId.map { x => 1 }.getOrElse(2) } var filesCompared = 0 while (entry != null) { if (!entry.isDirectory) { val expectedFile = { - if (legacy) { - val splits = entry.getName.split("/") - new File(new File(logDir, splits(0)), splits(1)) - } else { - new File(logDir, entry.getName) - } + new File(logDir, entry.getName) } val expected = Files.toString(expectedFile, Charsets.UTF_8) val actual = new String(ByteStreams.toByteArray(zipStream), Charsets.UTF_8) From 0fff5c6e6325357a241d311e72db942c4850af34 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 1 Feb 2016 23:08:11 -0800 Subject: [PATCH 091/131] [SPARK-13130][SQL] Make codegen variable names easier to read 1. Use lower case 2. Change long prefixes to something shorter (in this case I am changing only one: TungstenAggregate -> agg). Author: Reynold Xin Closes #11017 from rxin/SPARK-13130. 
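To make the effect concrete (the example prefixes below are illustrative; only the prefix rule itself comes from this patch), generated variable names in whole-stage codegen now use a lower-cased, per-operator prefix, with the long `TungstenAggregate` name special-cased to `agg`:

```scala
// Before: ctx.freshNamePrefix = nodeName
//   e.g. names prefixed with "TungstenAggregate" or "Filter"
// After:  ctx.freshNamePrefix = variablePrefix
//   TungstenAggregate  -> "agg"               (special-cased in the patch)
//   any other operator -> nodeName.toLowerCase, e.g. Filter -> "filter"
```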
--- .../spark/sql/execution/WholeStageCodegen.scala | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index ef81ba60f049..02b0f423ed43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.util.Utils /** @@ -33,6 +34,12 @@ import org.apache.spark.util.Utils */ trait CodegenSupport extends SparkPlan { + /** Prefix used in the current operator's variable names. */ + private def variablePrefix: String = this match { + case _: TungstenAggregate => "agg" + case _ => nodeName.toLowerCase + } + /** * Whether this SparkPlan support whole stage codegen or not. */ @@ -53,7 +60,7 @@ trait CodegenSupport extends SparkPlan { */ def produce(ctx: CodegenContext, parent: CodegenSupport): String = { this.parent = parent - ctx.freshNamePrefix = nodeName + ctx.freshNamePrefix = variablePrefix doProduce(ctx) } @@ -94,7 +101,7 @@ trait CodegenSupport extends SparkPlan { child: SparkPlan, input: Seq[ExprCode], row: String = null): String = { - ctx.freshNamePrefix = nodeName + ctx.freshNamePrefix = variablePrefix if (row != null) { ctx.currentVars = null ctx.INPUT_ROW = row From b8666fd0e2a797924eb2e94ac5558aba2a9b5140 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 1 Feb 2016 23:37:06 -0800 Subject: [PATCH 092/131] Closes #10662. Closes #10661 From 22ba21348b28d8b1909ccde6fe17fb9e68531e5a Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 2 Feb 2016 16:48:59 +0800 Subject: [PATCH 093/131] [SPARK-13087][SQL] Fix group by function for sort based aggregation It is not valid to call `toAttribute` on a `NamedExpression` unless we know for sure that the child produced that `NamedExpression`. The current code worked fine when the grouping expressions were simple, but when they were a derived value this blew up at execution time. Author: Michael Armbrust Closes #11013 from marmbrus/groupByFunction-master. 
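A minimal reproduction, lifted from the regression test added below (the table name and data come from that test; assumes a `sqlContext` with its implicits imported): grouping by a derived expression such as `floor(a)`, rather than a bare column, is what previously failed at execution time under sort-based aggregation.

```scala
import sqlContext.implicits._

// Group by a function of a column. Before this fix, the sort-based aggregation
// planner called .toAttribute on the grouping expressions, which is only valid
// for plain column references, so this query blew up at execution time.
Seq((1, 2)).toDF("a", "b").registerTempTable("data")
sqlContext.sql(
  "SELECT floor(a) AS a, collect_set(b) FROM data GROUP BY floor(a) ORDER BY a").show()
// Expected (per the new test): a single row (1, [2])
```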
--- .../org/apache/spark/sql/execution/aggregate/utils.scala | 5 ++--- .../spark/sql/hive/execution/AggregationQuerySuite.scala | 8 ++++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index 83379ae90f70..1e113ccd4e13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -33,15 +33,14 @@ object Utils { resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { - val groupingAttributes = groupingExpressions.map(_.toAttribute) val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete)) val completeAggregateAttributes = completeAggregateExpressions.map { expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) } SortBasedAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes, + requiredChildDistributionExpressions = Some(groupingExpressions), + groupingExpressions = groupingExpressions, aggregateExpressions = completeAggregateExpressions, aggregateAttributes = completeAggregateAttributes, initialInputBufferOffset = 0, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 3e4cf3f79e57..7a9ed1eaf3db 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -193,6 +193,14 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te sqlContext.dropTempTable("emptyTable") } + test("group by function") { + Seq((1, 2)).toDF("a", "b").registerTempTable("data") + + checkAnswer( + sql("SELECT floor(a) AS a, collect_set(b) FROM data GROUP BY floor(a) ORDER BY a"), + Row(1, Array(2)) :: Nil) + } + test("empty table") { // If there is no GROUP BY clause and the table is empty, we will generate a single row. checkAnswer( From 12a20c144f14e80ef120ddcfb0b455a805a2da23 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 2 Feb 2016 10:13:54 -0800 Subject: [PATCH 094/131] [SPARK-10820][SQL] Support for the continuous execution of structured queries This is a follow up to 9aadcffabd226557174f3ff566927f873c71672e that extends Spark SQL to allow users to _repeatedly_ optimize and execute structured queries. A `ContinuousQuery` can be expressed using SQL, DataFrames or Datasets. The purpose of this PR is only to add some initial infrastructure which will be extended in subsequent PRs. ## User-facing API - `sqlContext.streamFrom` and `df.streamTo` return builder objects that are analogous to the `read/write` interfaces already available to executing queries in a batch-oriented fashion. - `ContinuousQuery` provides an interface for interacting with a query that is currently executing in the background. ## Internal Interfaces - `StreamExecution` - executes streaming queries in micro-batches The following are currently internal, but public APIs will be provided in a future release. - `Source` - an interface for providers of continually arriving data. A source must have a notion of an `Offset` that monotonically tracks what data has arrived. 
For fault tolerance, a source must be able to replay data given a start offset. - `Sink` - an interface that accepts the results of a continuously executing query. Also responsible for tracking the offset that should be resumed from in the case of a failure. ## Testing - `MemoryStream` and `MemorySink` - simple implementations of source and sink that keep all data in memory and have methods for simulating durability failures - `StreamTest` - a framework for performing actions and checking invariants on a continuous query Author: Michael Armbrust Author: Tathagata Das Author: Josh Rosen Closes #11006 from marmbrus/structured-streaming. --- .../apache/spark/sql/ContinuousQuery.scala | 30 ++ .../org/apache/spark/sql/DataFrame.scala | 8 + .../apache/spark/sql/DataStreamReader.scala | 127 +++++++ .../apache/spark/sql/DataStreamWriter.scala | 134 +++++++ .../org/apache/spark/sql/SQLContext.scala | 8 + .../datasources/ResolvedDataSource.scala | 33 +- .../spark/sql/execution/streaming/Batch.scala | 26 ++ .../execution/streaming/CompositeOffset.scala | 67 ++++ .../sql/execution/streaming/LongOffset.scala | 33 ++ .../sql/execution/streaming/Offset.scala | 37 ++ .../spark/sql/execution/streaming/Sink.scala | 47 +++ .../sql/execution/streaming/Source.scala | 36 ++ .../execution/streaming/StreamExecution.scala | 211 +++++++++++ .../execution/streaming/StreamProgress.scala | 67 ++++ .../streaming/StreamingRelation.scala | 34 ++ .../sql/execution/streaming/memory.scala | 138 +++++++ .../apache/spark/sql/sources/interfaces.scala | 21 ++ .../org/apache/spark/sql/QueryTest.scala | 74 ++-- .../org/apache/spark/sql/StreamTest.scala | 346 ++++++++++++++++++ .../sql/streaming/DataStreamReaderSuite.scala | 166 +++++++++ .../streaming/MemorySourceStressSuite.scala | 33 ++ .../spark/sql/streaming/OffsetSuite.scala | 98 +++++ .../spark/sql/streaming/StreamSuite.scala | 84 +++++ .../spark/sql/test/SharedSQLContext.scala | 2 +- 24 files changed, 1828 insertions(+), 32 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/DataStreamReader.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/DataStreamWriter.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Batch.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala create mode 
100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala new file mode 100644 index 000000000000..1c2c0290fc4c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +/** + * A handle to a query that is executing continuously in the background as new data arrives. + */ +trait ContinuousQuery { + + /** + * Stops the execution of this query if it is running. This method blocks until the threads + * performing execution has stopped. + */ + def stop(): Unit +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 518f9dcf94a7..6de17e5924d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1690,6 +1690,14 @@ class DataFrame private[sql]( @Experimental def write: DataFrameWriter = new DataFrameWriter(this) + /** + * :: Experimental :: + * Interface for starting a streaming query that will continually output results to the specified + * external sink as new data arrives. + */ + @Experimental + def streamTo: DataStreamWriter = new DataStreamWriter(this) + /** * Returns the content of the [[DataFrame]] as a RDD of JSON strings. * @group rdd diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataStreamReader.scala new file mode 100644 index 000000000000..2febc93fa49d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataStreamReader.scala @@ -0,0 +1,127 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql + +import scala.collection.JavaConverters._ + +import org.apache.spark.Logging +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.execution.datasources.ResolvedDataSource +import org.apache.spark.sql.execution.streaming.StreamingRelation +import org.apache.spark.sql.types.StructType + +/** + * :: Experimental :: + * An interface to reading streaming data. Use `sqlContext.streamFrom` to access these methods. + * + * {{{ + * val df = sqlContext.streamFrom + * .format("...") + * .open() + * }}} + */ +@Experimental +class DataStreamReader private[sql](sqlContext: SQLContext) extends Logging { + + /** + * Specifies the input data source format. + * + * @since 2.0.0 + */ + def format(source: String): DataStreamReader = { + this.source = source + this + } + + /** + * Specifies the input schema. Some data streams (e.g. JSON) can infer the input schema + * automatically from data. By specifying the schema here, the underlying data stream can + * skip the schema inference step, and thus speed up data reading. + * + * @since 2.0.0 + */ + def schema(schema: StructType): DataStreamReader = { + this.userSpecifiedSchema = Option(schema) + this + } + + /** + * Adds an input option for the underlying data stream. + * + * @since 2.0.0 + */ + def option(key: String, value: String): DataStreamReader = { + this.extraOptions += (key -> value) + this + } + + /** + * (Scala-specific) Adds input options for the underlying data stream. + * + * @since 2.0.0 + */ + def options(options: scala.collection.Map[String, String]): DataStreamReader = { + this.extraOptions ++= options + this + } + + /** + * Adds input options for the underlying data stream. + * + * @since 2.0.0 + */ + def options(options: java.util.Map[String, String]): DataStreamReader = { + this.options(options.asScala) + this + } + + /** + * Loads streaming input in as a [[DataFrame]], for data streams that don't require a path (e.g. + * external key-value stores). + * + * @since 2.0.0 + */ + def open(): DataFrame = { + val resolved = ResolvedDataSource.createSource( + sqlContext, + userSpecifiedSchema = userSpecifiedSchema, + providerName = source, + options = extraOptions.toMap) + DataFrame(sqlContext, StreamingRelation(resolved)) + } + + /** + * Loads input in as a [[DataFrame]], for data streams that read from some path. + * + * @since 2.0.0 + */ + def open(path: String): DataFrame = { + option("path", path).open() + } + + /////////////////////////////////////////////////////////////////////////////////////// + // Builder pattern config options + /////////////////////////////////////////////////////////////////////////////////////// + + private var source: String = sqlContext.conf.defaultDataSourceName + + private var userSpecifiedSchema: Option[StructType] = None + + private var extraOptions = new scala.collection.mutable.HashMap[String, String] + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataStreamWriter.scala new file mode 100644 index 000000000000..b325d48fcbbb --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataStreamWriter.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.execution.datasources.ResolvedDataSource +import org.apache.spark.sql.execution.streaming.StreamExecution + +/** + * :: Experimental :: + * Interface used to start a streaming query execution. + * + * @since 2.0.0 + */ +@Experimental +final class DataStreamWriter private[sql](df: DataFrame) { + + /** + * Specifies the underlying output data source. Built-in options include "parquet", "json", etc. + * + * @since 2.0.0 + */ + def format(source: String): DataStreamWriter = { + this.source = source + this + } + + /** + * Adds an output option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: String): DataStreamWriter = { + this.extraOptions += (key -> value) + this + } + + /** + * (Scala-specific) Adds output options for the underlying data source. + * + * @since 2.0.0 + */ + def options(options: scala.collection.Map[String, String]): DataStreamWriter = { + this.extraOptions ++= options + this + } + + /** + * Adds output options for the underlying data source. + * + * @since 2.0.0 + */ + def options(options: java.util.Map[String, String]): DataStreamWriter = { + this.options(options.asScala) + this + } + + /** + * Partitions the output by the given columns on the file system. If specified, the output is + * laid out on the file system similar to Hive's partitioning scheme. + * @since 2.0.0 + */ + @scala.annotation.varargs + def partitionBy(colNames: String*): DataStreamWriter = { + this.partitioningColumns = colNames + this + } + + /** + * Starts the execution of the streaming query, which will continually output results to the given + * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with + * the stream. + * @since 2.0.0 + */ + def start(path: String): ContinuousQuery = { + this.extraOptions += ("path" -> path) + start() + } + + /** + * Starts the execution of the streaming query, which will continually output results to the given + * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with + * the stream.
+ * + * @since 2.0.0 + */ + def start(): ContinuousQuery = { + val sink = ResolvedDataSource.createSink( + df.sqlContext, + source, + extraOptions.toMap, + normalizedParCols) + + new StreamExecution(df.sqlContext, df.logicalPlan, sink) + } + + private def normalizedParCols: Seq[String] = { + partitioningColumns.map { col => + df.logicalPlan.output + .map(_.name) + .find(df.sqlContext.analyzer.resolver(_, col)) + .getOrElse(throw new AnalysisException(s"Partition column $col not found in existing " + + s"columns (${df.logicalPlan.output.map(_.name).mkString(", ")})")) + } + } + + /////////////////////////////////////////////////////////////////////////////////////// + // Builder pattern config options + /////////////////////////////////////////////////////////////////////////////////////// + + private var source: String = df.sqlContext.conf.defaultDataSourceName + + private var extraOptions = new scala.collection.mutable.HashMap[String, String] + + private var partitioningColumns: Seq[String] = Nil + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index ef993c3edae3..13700be06828 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -594,6 +594,14 @@ class SQLContext private[sql]( @Experimental def read: DataFrameReader = new DataFrameReader(this) + + /** + * :: Experimental :: + * Returns a [[DataStreamReader]] than can be used to access data continuously as it arrives. + */ + @Experimental + def streamFrom: DataStreamReader = new DataStreamReader(this) + /** * :: Experimental :: * Creates an external table from the given path and returns the corresponding DataFrame. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala index cc8dcf59307f..e3065ac5f87d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -29,11 +29,11 @@ import org.apache.hadoop.util.StringUtils import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} +import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{CalendarIntervalType, StructType} import org.apache.spark.util.Utils - case class ResolvedDataSource(provider: Class[_], relation: BaseRelation) @@ -92,6 +92,37 @@ object ResolvedDataSource extends Logging { } } + def createSource( + sqlContext: SQLContext, + userSpecifiedSchema: Option[StructType], + providerName: String, + options: Map[String, String]): Source = { + val provider = lookupDataSource(providerName).newInstance() match { + case s: StreamSourceProvider => s + case _ => + throw new UnsupportedOperationException( + s"Data source $providerName does not support streamed reading") + } + + provider.createSource(sqlContext, options, userSpecifiedSchema) + } + + def createSink( + sqlContext: SQLContext, + providerName: String, + options: Map[String, String], + partitionColumns: Seq[String]): Sink = { + val provider = lookupDataSource(providerName).newInstance() match { + case s: StreamSinkProvider => s + case _ => + throw new UnsupportedOperationException( + s"Data source $providerName does not support streamed writing") + } + + provider.createSink(sqlContext, options, partitionColumns) + } + + /** Create a [[ResolvedDataSource]] for reading data in. */ def apply( sqlContext: SQLContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Batch.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Batch.scala new file mode 100644 index 000000000000..1f25eb8fc522 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Batch.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.DataFrame + +/** + * Used to pass a batch of data through a streaming query execution along with an indication + * of progress in the stream. 
+ */ +class Batch(val end: Offset, val data: DataFrame) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala new file mode 100644 index 000000000000..d2cb20ef8b81 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.util.Try + +/** + * An ordered collection of offsets, used to track the progress of processing data from one or more + * [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance + * vector clock that must progress linearly forward. + */ +case class CompositeOffset(offsets: Seq[Option[Offset]]) extends Offset { + /** + * Returns a negative integer, zero, or a positive integer as this object is less than, equal to, + * or greater than the specified object. + */ + override def compareTo(other: Offset): Int = other match { + case otherComposite: CompositeOffset if otherComposite.offsets.size == offsets.size => + val comparisons = offsets.zip(otherComposite.offsets).map { + case (Some(a), Some(b)) => a compareTo b + case (None, None) => 0 + case (None, _) => -1 + case (_, None) => 1 + } + val nonZeroSigns = comparisons.map(sign).filter(_ != 0).toSet + nonZeroSigns.size match { + case 0 => 0 // if both empty or only 0s + case 1 => nonZeroSigns.head // if there are only (0s and 1s) or (0s and -1s) + case _ => // there are both 1s and -1s + throw new IllegalArgumentException( + s"Invalid comparison between non-linear histories: $this <=> $other") + } + case _ => + throw new IllegalArgumentException(s"Cannot compare $this <=> $other") + } + + private def sign(num: Int): Int = num match { + case i if i < 0 => -1 + case i if i == 0 => 0 + case i if i > 0 => 1 + } +} + +object CompositeOffset { + /** + * Returns a [[CompositeOffset]] with a variable sequence of offsets. + * `nulls` in the sequence are converted to `None`s. + */ + def fill(offsets: Offset*): CompositeOffset = { + CompositeOffset(offsets.map(Option(_))) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala new file mode 100644 index 000000000000..008195af38b7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +/** + * A simple offset for sources that produce a single linear stream of data. + */ +case class LongOffset(offset: Long) extends Offset { + + override def compareTo(other: Offset): Int = other match { + case l: LongOffset => offset.compareTo(l.offset) + case _ => + throw new IllegalArgumentException(s"Invalid comparison of $getClass with ${other.getClass}") + } + + def +(increment: Long): LongOffset = new LongOffset(offset + increment) + def -(decrement: Long): LongOffset = new LongOffset(offset - decrement) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala new file mode 100644 index 000000000000..0f5d6445b1e2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +/** + * A offset is a monotonically increasing metric used to track progress in the computation of a + * stream. An [[Offset]] must be comparable, and the result of `compareTo` must be consistent + * with `equals` and `hashcode`. + */ +trait Offset extends Serializable { + + /** + * Returns a negative integer, zero, or a positive integer as this object is less than, equal to, + * or greater than the specified object. 
+ */ + def compareTo(other: Offset): Int + + def >(other: Offset): Boolean = compareTo(other) > 0 + def <(other: Offset): Boolean = compareTo(other) < 0 + def <=(other: Offset): Boolean = compareTo(other) <= 0 + def >=(other: Offset): Boolean = compareTo(other) >= 0 +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala new file mode 100644 index 000000000000..1bd71b6b02ea --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +/** + * An interface for systems that can collect the results of a streaming query. + * + * When new data is produced by a query, a [[Sink]] must be able to transactionally collect the + * data and update the [[Offset]]. In the case of a failure, the sink will be recreated + * and must be able to return the [[Offset]] for all of the data that is made durable. + * This contract allows Spark to process data with exactly-once semantics, even in the case + * of failures that require the computation to be restarted. + */ +trait Sink { + /** + * Returns the [[Offset]] for all data that is currently present in the sink, if any. This + * function will be called by Spark when restarting execution in order to determine at which point + * in the input stream computation should be resumed from. + */ + def currentOffset: Option[Offset] + + /** + * Accepts a new batch of data as well as an [[Offset]] that denotes how far in the input + * data computation has progressed to. When computation restarts after a failure, it is important + * that a [[Sink]] returns the same [[Offset]] as the most recent batch of data that + * has been persisted durably. Note that this does not necessarily have to be the + * [[Offset]] for the most recent batch of data that was given to the sink. For example, + * it is valid to buffer data before persisting, as long as the [[Offset]] is stored + * transactionally as data is eventually persisted. + */ + def addBatch(batch: Batch): Unit +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala new file mode 100644 index 000000000000..25922979ac83 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.types.StructType + +/** + * A source of continually arriving data for a streaming query. A [[Source]] must have a + * monotonically increasing notion of progress that can be represented as an [[Offset]]. Spark + * will regularly query each [[Source]] to see if any more data is available. + */ +trait Source { + + /** Returns the schema of the data from this source */ + def schema: StructType + + /** + * Returns the next batch of data that is available after `start`, if any is available. + */ + def getNextBatch(start: Option[Offset]): Option[Batch] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala new file mode 100644 index 000000000000..ebebb829710b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.lang.Thread.UncaughtExceptionHandler + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.Logging +import org.apache.spark.sql.{ContinuousQuery, DataFrame, SQLContext} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.QueryExecution + +/** + * Manages the execution of a streaming Spark SQL query that is occurring in a separate thread. + * Unlike a standard query, a streaming query executes repeatedly each time new data arrives at any + * [[Source]] present in the query plan. Whenever new data arrives, a [[QueryExecution]] is created + * and the results are committed transactionally to the given [[Sink]]. + */ +class StreamExecution( + sqlContext: SQLContext, + private[sql] val logicalPlan: LogicalPlan, + val sink: Sink) extends ContinuousQuery with Logging { + + /** An monitor used to wait/notify when batches complete. 
*/ + private val awaitBatchLock = new Object + + @volatile + private var batchRun = false + + /** Minimum amount of time in between the start of each batch. */ + private val minBatchTime = 10 + + /** Tracks how much data we have processed from each input source. */ + private[sql] val streamProgress = new StreamProgress + + /** All stream sources present the query plan. */ + private val sources = + logicalPlan.collect { case s: StreamingRelation => s.source } + + // Start the execution at the current offsets stored in the sink. (i.e. avoid reprocessing data + // that we have already processed). + { + sink.currentOffset match { + case Some(c: CompositeOffset) => + val storedProgress = c.offsets + val sources = logicalPlan collect { + case StreamingRelation(source, _) => source + } + + assert(sources.size == storedProgress.size) + sources.zip(storedProgress).foreach { case (source, offset) => + offset.foreach(streamProgress.update(source, _)) + } + case None => // We are starting this stream for the first time. + case _ => throw new IllegalArgumentException("Expected composite offset from sink") + } + } + + logInfo(s"Stream running at $streamProgress") + + /** When false, signals to the microBatchThread that it should stop running. */ + @volatile private var shouldRun = true + + /** The thread that runs the micro-batches of this stream. */ + private[sql] val microBatchThread = new Thread("stream execution thread") { + override def run(): Unit = { + SQLContext.setActive(sqlContext) + while (shouldRun) { + attemptBatch() + Thread.sleep(minBatchTime) // TODO: Could be tighter + } + } + } + microBatchThread.setDaemon(true) + microBatchThread.setUncaughtExceptionHandler( + new UncaughtExceptionHandler { + override def uncaughtException(t: Thread, e: Throwable): Unit = { + streamDeathCause = e + } + }) + microBatchThread.start() + + @volatile + private[sql] var lastExecution: QueryExecution = null + @volatile + private[sql] var streamDeathCause: Throwable = null + + /** + * Checks to see if any new data is present in any of the sources. When new data is available, + * a batch is executed and passed to the sink, updating the currentOffsets. + */ + private def attemptBatch(): Unit = { + val startTime = System.nanoTime() + + // A list of offsets that need to be updated if this batch is successful. + // Populated while walking the tree. + val newOffsets = new ArrayBuffer[(Source, Offset)] + // A list of attributes that will need to be updated. + var replacements = new ArrayBuffer[(Attribute, Attribute)] + // Replace sources in the logical plan with data that has arrived since the last batch. + val withNewSources = logicalPlan transform { + case StreamingRelation(source, output) => + val prevOffset = streamProgress.get(source) + val newBatch = source.getNextBatch(prevOffset) + + newBatch.map { batch => + newOffsets += ((source, batch.end)) + val newPlan = batch.data.logicalPlan + + assert(output.size == newPlan.output.size) + replacements ++= output.zip(newPlan.output) + newPlan + }.getOrElse { + LocalRelation(output) + } + } + + // Rewire the plan to use the new attributes that were returned by the source. 
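+    // Each source's original output attributes are mapped to the attributes of the batch that
+    // replaced it, so that expressions above the relation keep resolving against the new data.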
+ val replacementMap = AttributeMap(replacements) + val newPlan = withNewSources transformAllExpressions { + case a: Attribute if replacementMap.contains(a) => replacementMap(a) + } + + if (newOffsets.nonEmpty) { + val optimizerStart = System.nanoTime() + + lastExecution = new QueryExecution(sqlContext, newPlan) + val executedPlan = lastExecution.executedPlan + val optimizerTime = (System.nanoTime() - optimizerStart).toDouble / 1000000 + logDebug(s"Optimized batch in ${optimizerTime}ms") + + streamProgress.synchronized { + // Update the offsets and calculate a new composite offset + newOffsets.foreach(streamProgress.update) + val newStreamProgress = logicalPlan.collect { + case StreamingRelation(source, _) => streamProgress.get(source) + } + val batchOffset = CompositeOffset(newStreamProgress) + + // Construct the batch and send it to the sink. + val nextBatch = new Batch(batchOffset, new DataFrame(sqlContext, newPlan)) + sink.addBatch(nextBatch) + } + + batchRun = true + awaitBatchLock.synchronized { + // Wake up any threads that are waiting for the stream to progress. + awaitBatchLock.notifyAll() + } + + val batchTime = (System.nanoTime() - startTime).toDouble / 1000000 + logInfo(s"Compete up to $newOffsets in ${batchTime}ms") + } + + logDebug(s"Waiting for data, current: $streamProgress") + } + + /** + * Signals to the thread executing micro-batches that it should stop running after the next + * batch. This method blocks until the thread stops running. + */ + def stop(): Unit = { + shouldRun = false + if (microBatchThread.isAlive) { microBatchThread.join() } + } + + /** + * Blocks the current thread until processing for data from the given `source` has reached at + * least the given `Offset`. This method is indented for use primarily when writing tests. + */ + def awaitOffset(source: Source, newOffset: Offset): Unit = { + def notDone = streamProgress.synchronized { + !streamProgress.contains(source) || streamProgress(source) < newOffset + } + + while (notDone) { + logInfo(s"Waiting until $newOffset at $source") + awaitBatchLock.synchronized { awaitBatchLock.wait(100) } + } + logDebug(s"Unblocked at $newOffset for $source") + } + + override def toString: String = + s""" + |=== Streaming Query === + |CurrentOffsets: $streamProgress + |Thread State: ${microBatchThread.getState} + |${if (streamDeathCause != null) stackTraceToString(streamDeathCause) else ""} + | + |$logicalPlan + """.stripMargin +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala new file mode 100644 index 000000000000..0ded1d7152c1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.collection.mutable + +/** + * A helper class that looks like a Map[Source, Offset]. + */ +class StreamProgress { + private val currentOffsets = new mutable.HashMap[Source, Offset] + + private[streaming] def update(source: Source, newOffset: Offset): Unit = { + currentOffsets.get(source).foreach(old => + assert(newOffset > old, s"Stream going backwards $newOffset -> $old")) + currentOffsets.put(source, newOffset) + } + + private[streaming] def update(newOffset: (Source, Offset)): Unit = + update(newOffset._1, newOffset._2) + + private[streaming] def apply(source: Source): Offset = currentOffsets(source) + private[streaming] def get(source: Source): Option[Offset] = currentOffsets.get(source) + private[streaming] def contains(source: Source): Boolean = currentOffsets.contains(source) + + private[streaming] def ++(updates: Map[Source, Offset]): StreamProgress = { + val updated = new StreamProgress + currentOffsets.foreach(updated.update) + updates.foreach(updated.update) + updated + } + + /** + * Used to create a new copy of this [[StreamProgress]]. While this class is currently mutable, + * it should be copied before being passed to user code. + */ + private[streaming] def copy(): StreamProgress = { + val copied = new StreamProgress + currentOffsets.foreach(copied.update) + copied + } + + override def toString: String = + currentOffsets.map { case (k, v) => s"$k: $v"}.mkString("{", ",", "}") + + override def equals(other: Any): Boolean = other match { + case s: StreamProgress => currentOffsets == s.currentOffsets + case _ => false + } + + override def hashCode: Int = currentOffsets.hashCode() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala new file mode 100644 index 000000000000..e35c444348f4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LeafNode + +object StreamingRelation { + def apply(source: Source): StreamingRelation = + StreamingRelation(source, source.schema.toAttributes) +} + +/** + * Used to link a streaming [[Source]] of data into a + * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. 
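+ * When the stream is executed, this relation is replaced by the data of each new batch.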
+ */
+case class StreamingRelation(source: Source, output: Seq[Attribute]) extends LeafNode {
+  override def toString: String = source.toString
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
new file mode 100644
index 000000000000..e6a0842936ea
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{Logging, SparkEnv}
+import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row, SQLContext}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, RowEncoder}
+import org.apache.spark.sql.types.StructType
+
+object MemoryStream {
+  protected val currentBlockId = new AtomicInteger(0)
+  protected val memoryStreamId = new AtomicInteger(0)
+
+  def apply[A : Encoder](implicit sqlContext: SQLContext): MemoryStream[A] =
+    new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext)
+}
+
+/**
+ * A [[Source]] that produces values stored in memory as they are added by the user. This [[Source]]
+ * is primarily intended for use in unit tests as it can only replay data when the object is still
+ * available.
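+ * Data added through `addData` is buffered as `Dataset`s so that it can be replayed on demand.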
+ */ +case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) + extends Source with Logging { + protected val encoder = encoderFor[A] + protected val logicalPlan = StreamingRelation(this) + protected val output = logicalPlan.output + protected val batches = new ArrayBuffer[Dataset[A]] + protected var currentOffset: LongOffset = new LongOffset(-1) + + protected def blockManager = SparkEnv.get.blockManager + + def schema: StructType = encoder.schema + + def getCurrentOffset: Offset = currentOffset + + def toDS()(implicit sqlContext: SQLContext): Dataset[A] = { + new Dataset(sqlContext, logicalPlan) + } + + def toDF()(implicit sqlContext: SQLContext): DataFrame = { + new DataFrame(sqlContext, logicalPlan) + } + + def addData(data: TraversableOnce[A]): Offset = { + import sqlContext.implicits._ + this.synchronized { + currentOffset = currentOffset + 1 + val ds = data.toVector.toDS() + logDebug(s"Adding ds: $ds") + batches.append(ds) + currentOffset + } + } + + override def getNextBatch(start: Option[Offset]): Option[Batch] = synchronized { + val newBlocks = + batches.drop( + start.map(_.asInstanceOf[LongOffset]).getOrElse(LongOffset(-1)).offset.toInt + 1) + + if (newBlocks.nonEmpty) { + logDebug(s"Running [$start, $currentOffset] on blocks ${newBlocks.mkString(", ")}") + val df = newBlocks + .map(_.toDF()) + .reduceOption(_ unionAll _) + .getOrElse(sqlContext.emptyDataFrame) + + Some(new Batch(currentOffset, df)) + } else { + None + } + } + + override def toString: String = s"MemoryStream[${output.mkString(",")}]" +} + +/** + * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit + * tests and does not provide durablility. + */ +class MemorySink(schema: StructType) extends Sink with Logging { + /** An order list of batches that have been written to this [[Sink]]. */ + private var batches = new ArrayBuffer[Batch]() + + /** Used to convert an [[InternalRow]] to an external [[Row]] for comparison in testing. */ + private val externalRowConverter = RowEncoder(schema) + + override def currentOffset: Option[Offset] = synchronized { + batches.lastOption.map(_.end) + } + + override def addBatch(nextBatch: Batch): Unit = synchronized { + batches.append(nextBatch) + } + + /** Returns all rows that are stored in this [[Sink]]. */ + def allData: Seq[Row] = synchronized { + batches + .map(_.data) + .reduceOption(_ unionAll _) + .map(_.collect().toSeq) + .getOrElse(Seq.empty) + } + + /** + * Atomically drops the most recent `num` batches and resets the [[StreamProgress]] to the + * corresponding point in the input. This function can be used when testing to simulate data + * that has been lost due to buffering. 
+ */ + def dropBatches(num: Int): Unit = synchronized { + batches.dropRight(num) + } + + override def toString: String = synchronized { + batches.map(b => s"${b.end}: ${b.data.collect().mkString(" ")}").mkString("\n") + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 8911ad370aa7..299fc6efbb04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.execution.{FileRelation, RDDConversions} import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -123,6 +124,26 @@ trait SchemaRelationProvider { schema: StructType): BaseRelation } +/** + * Implemented by objects that can produce a streaming [[Source]] for a specific format or system. + */ +trait StreamSourceProvider { + def createSource( + sqlContext: SQLContext, + parameters: Map[String, String], + schema: Option[StructType]): Source +} + +/** + * Implemented by objects that can produce a streaming [[Sink]] for a specific format or system. + */ +trait StreamSinkProvider { + def createSink( + sqlContext: SQLContext, + parameters: Map[String, String], + partitionColumns: Seq[String]): Sink +} + /** * ::Experimental:: * Implemented by objects that produce relations for a specific kind of data source diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index ce12f788b786..405e5891ac97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -304,27 +304,7 @@ object QueryTest { def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Option[String] = { val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty - // We need to call prepareRow recursively to handle schemas with struct types. - def prepareRow(row: Row): Row = { - Row.fromSeq(row.toSeq.map { - case null => null - case d: java.math.BigDecimal => BigDecimal(d) - // Convert array to Seq for easy equality check. - case b: Array[_] => b.toSeq - case r: Row => prepareRow(r) - case o => o - }) - } - def prepareAnswer(answer: Seq[Row]): Seq[Row] = { - // Converts data to types that we can do equality comparison using Scala collections. - // For BigDecimal type, the Scala type has a better definition of equality test (similar to - // Java's java.math.BigDecimal.compareTo). - // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for - // equality test. 
- val converted: Seq[Row] = answer.map(prepareRow) - if (!isSorted) converted.sortBy(_.toString()) else converted - } val sparkAnswer = try df.collect().toSeq catch { case e: Exception => val errorMessage = @@ -338,22 +318,56 @@ object QueryTest { return Some(errorMessage) } - if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { - val errorMessage = + sameRows(expectedAnswer, sparkAnswer, isSorted).map { results => s""" |Results do not match for query: |${df.queryExecution} |== Results == - |${sideBySide( - s"== Correct Answer - ${expectedAnswer.size} ==" +: - prepareAnswer(expectedAnswer).map(_.toString()), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} - """.stripMargin - return Some(errorMessage) + |$results + """.stripMargin } + } + + + def prepareAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = { + // Converts data to types that we can do equality comparison using Scala collections. + // For BigDecimal type, the Scala type has a better definition of equality test (similar to + // Java's java.math.BigDecimal.compareTo). + // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. + val converted: Seq[Row] = answer.map(prepareRow) + if (!isSorted) converted.sortBy(_.toString()) else converted + } - return None + // We need to call prepareRow recursively to handle schemas with struct types. + def prepareRow(row: Row): Row = { + Row.fromSeq(row.toSeq.map { + case null => null + case d: java.math.BigDecimal => BigDecimal(d) + // Convert array to Seq for easy equality check. + case b: Array[_] => b.toSeq + case r: Row => prepareRow(r) + case o => o + }) + } + + def sameRows( + expectedAnswer: Seq[Row], + sparkAnswer: Seq[Row], + isSorted: Boolean = false): Option[String] = { + if (prepareAnswer(expectedAnswer, isSorted) != prepareAnswer(sparkAnswer, isSorted)) { + val errorMessage = + s""" + |== Results == + |${sideBySide( + s"== Correct Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer, isSorted).map(_.toString()), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer, isSorted).map(_.toString())).mkString("\n")} + """.stripMargin + return Some(errorMessage) + } + None } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala new file mode 100644 index 000000000000..f45abbf2496a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala @@ -0,0 +1,346 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import java.lang.Thread.UncaughtExceptionHandler + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.Random + +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.streaming._ + +/** + * A framework for implementing tests for streaming queries and sources. + * + * A test consists of a set of steps (expressed as a `StreamAction`) that are executed in order, + * blocking as necessary to let the stream catch up. For example, the following adds some data to + * a stream, blocking until it can verify that the correct values are eventually produced. + * + * {{{ + * val inputData = MemoryStream[Int] + val mapped = inputData.toDS().map(_ + 1) + + testStream(mapped)( + AddData(inputData, 1, 2, 3), + CheckAnswer(2, 3, 4)) + * }}} + * + * Note that while we do sleep to allow the other thread to progress without spinning, + * `StreamAction` checks should not depend on the amount of time spent sleeping. Instead they + * should check the actual progress of the stream before verifying the required test condition. + * + * Currently it is assumed that all streaming queries will eventually complete in 10 seconds to + * avoid hanging forever in the case of failures. However, individual suites can change this + * by overriding `streamingTimeout`. + */ +trait StreamTest extends QueryTest with Timeouts { + + implicit class RichSource(s: Source) { + def toDF(): DataFrame = new DataFrame(sqlContext, StreamingRelation(s)) + } + + /** How long to wait for an active stream to catch up when checking a result. */ + val streamingTimout = 10.seconds + + /** A trait for actions that can be performed while testing a streaming DataFrame. */ + trait StreamAction + + /** A trait to mark actions that require the stream to be actively running. */ + trait StreamMustBeRunning + + /** + * Adds the given data to the stream. Subsuquent check answers will block until this data has + * been processed. + */ + object AddData { + def apply[A](source: MemoryStream[A], data: A*): AddDataMemory[A] = + AddDataMemory(source, data) + } + + /** A trait that can be extended when testing other sources. */ + trait AddData extends StreamAction { + def source: Source + + /** + * Called to trigger adding the data. Should return the offset that will denote when this + * new data has been processed. + */ + def addData(): Offset + } + + case class AddDataMemory[A](source: MemoryStream[A], data: Seq[A]) extends AddData { + override def toString: String = s"AddData to $source: ${data.mkString(",")}" + + override def addData(): Offset = { + source.addData(data) + } + } + + /** + * Checks to make sure that the current data stored in the sink matches the `expectedAnswer`. + * This operation automatically blocks untill all added data has been processed. 
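+   * Rows are compared without regard to their order.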
+ */
+  object CheckAnswer {
+    def apply[A : Encoder](data: A*): CheckAnswerRows = {
+      val encoder = encoderFor[A]
+      val toExternalRow = RowEncoder(encoder.schema)
+      CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))))
+    }
+
+    def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows)
+  }
+
+  case class CheckAnswerRows(expectedAnswer: Seq[Row])
+      extends StreamAction with StreamMustBeRunning {
+    override def toString: String = s"CheckAnswer: ${expectedAnswer.mkString(",")}"
+  }
+
+  case class DropBatches(num: Int) extends StreamAction
+
+  /** Stops the stream. It must currently be running. */
+  case object StopStream extends StreamAction with StreamMustBeRunning
+
+  /** Starts the stream, resuming if data has already been processed. It must not be running. */
+  case object StartStream extends StreamAction
+
+  /** Signals that a failure is expected and should not kill the test. */
+  case object ExpectFailure extends StreamAction
+
+  /** A helper for running actions on a Streaming Dataset. See `checkAnswer(DataFrame)`. */
+  def testStream(stream: Dataset[_])(actions: StreamAction*): Unit =
+    testStream(stream.toDF())(actions: _*)
+
+  /**
+   * Executes the specified actions on the given streaming DataFrame and provides helpful
+   * error messages in the case of failures or incorrect answers.
+   *
+   * Note that if the stream is not explicitly started before an action that requires it to be
+   * running then it will be automatically started before performing any other actions.
+   */
+  def testStream(stream: DataFrame)(actions: StreamAction*): Unit = {
+    var pos = 0
+    var currentPlan: LogicalPlan = stream.logicalPlan
+    var currentStream: StreamExecution = null
+    val awaiting = new mutable.HashMap[Source, Offset]()
+    val sink = new MemorySink(stream.schema)
+
+    @volatile
+    var streamDeathCause: Throwable = null
+
+    // If the test doesn't manually start the stream, we do it automatically at the beginning.
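+    // A `StartStream` that appears before the first action requiring a running stream counts as
+    // a manual start.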
+ val startedManually = + actions.takeWhile(!_.isInstanceOf[StreamMustBeRunning]).contains(StartStream) + val startedTest = if (startedManually) actions else StartStream +: actions + + def testActions = actions.zipWithIndex.map { + case (a, i) => + if ((pos == i && startedManually) || (pos == (i + 1) && !startedManually)) { + "=> " + a.toString + } else { + " " + a.toString + } + }.mkString("\n") + + def currentOffsets = + if (currentStream != null) currentStream.streamProgress.toString else "not started" + + def threadState = + if (currentStream != null && currentStream.microBatchThread.isAlive) "alive" else "dead" + def testState = + s""" + |== Progress == + |$testActions + | + |== Stream == + |Stream state: $currentOffsets + |Thread state: $threadState + |${if (streamDeathCause != null) stackTraceToString(streamDeathCause) else ""} + | + |== Sink == + |$sink + | + |== Plan == + |${if (currentStream != null) currentStream.lastExecution else ""} + """ + + def checkState(check: Boolean, error: String) = if (!check) { + fail( + s""" + |Invalid State: $error + |$testState + """.stripMargin) + } + + val testThread = Thread.currentThread() + + try { + startedTest.foreach { action => + action match { + case StartStream => + checkState(currentStream == null, "stream already running") + + currentStream = new StreamExecution(sqlContext, stream.logicalPlan, sink) + currentStream.microBatchThread.setUncaughtExceptionHandler( + new UncaughtExceptionHandler { + override def uncaughtException(t: Thread, e: Throwable): Unit = { + streamDeathCause = e + testThread.interrupt() + } + }) + + case StopStream => + checkState(currentStream != null, "can not stop a stream that is not running") + currentStream.stop() + currentStream = null + + case DropBatches(num) => + checkState(currentStream == null, "dropping batches while running leads to corruption") + sink.dropBatches(num) + + case ExpectFailure => + try failAfter(streamingTimout) { + while (streamDeathCause == null) { + Thread.sleep(100) + } + } catch { + case _: InterruptedException => + case _: org.scalatest.exceptions.TestFailedDueToTimeoutException => + fail( + s""" + |Timed out while waiting for failure. + |$testState + """.stripMargin) + } + + currentStream = null + streamDeathCause = null + + case a: AddData => + awaiting.put(a.source, a.addData()) + + case CheckAnswerRows(expectedAnswer) => + checkState(currentStream != null, "stream not running") + + // Block until all data added has been processed + awaiting.foreach { case (source, offset) => + failAfter(streamingTimout) { + currentStream.awaitOffset(source, offset) + } + } + + val allData = try sink.allData catch { + case e: Exception => + fail( + s""" + |Exception while getting data from sink $e + |$testState + """.stripMargin) + } + + QueryTest.sameRows(expectedAnswer, allData).foreach { + error => fail( + s""" + |$error + |$testState + """.stripMargin) + } + } + pos += 1 + } + } catch { + case _: InterruptedException if streamDeathCause != null => + fail( + s""" + |Stream Thread Died + |$testState + """.stripMargin) + case _: org.scalatest.exceptions.TestFailedDueToTimeoutException => + fail( + s""" + |Timed out waiting for stream + |$testState + """.stripMargin) + } finally { + if (currentStream != null && currentStream.microBatchThread.isAlive) { + currentStream.stop() + } + } + } + + /** + * Creates a stress test that randomly starts/stops/adds data/checks the result. + * + * @param ds a dataframe that executes + 1 on a stream of integers, returning the result. 
+ * @param addData and add data action that adds the given numbers to the stream, encoding them + * as needed + */ + def runStressTest( + ds: Dataset[Int], + addData: Seq[Int] => StreamAction, + iterations: Int = 100): Unit = { + implicit val intEncoder = ExpressionEncoder[Int]() + var dataPos = 0 + var running = true + val actions = new ArrayBuffer[StreamAction]() + + def addCheck() = { actions += CheckAnswer(1 to dataPos: _*) } + + def addRandomData() = { + val numItems = Random.nextInt(10) + val data = dataPos until (dataPos + numItems) + dataPos += numItems + actions += addData(data) + } + + (1 to iterations).foreach { i => + val rand = Random.nextDouble() + if(!running) { + rand match { + case r if r < 0.7 => // AddData + addRandomData() + + case _ => // StartStream + actions += StartStream + running = true + } + } else { + rand match { + case r if r < 0.1 => + addCheck() + + case r if r < 0.7 => // AddData + addRandomData() + + case _ => // StartStream + actions += StopStream + running = false + } + } + } + if(!running) { actions += StartStream } + addCheck() + testStream(ds)(actions: _*) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala new file mode 100644 index 000000000000..1dab6ebf1bee --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.test + +import org.apache.spark.sql.{AnalysisException, SQLContext, StreamTest} +import org.apache.spark.sql.execution.streaming.{Batch, Offset, Sink, Source} +import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + +object LastOptions { + var parameters: Map[String, String] = null + var schema: Option[StructType] = null + var partitionColumns: Seq[String] = Nil +} + +/** Dummy provider: returns no-op source/sink and records options in [[LastOptions]]. 
*/ +class DefaultSource extends StreamSourceProvider with StreamSinkProvider { + override def createSource( + sqlContext: SQLContext, + parameters: Map[String, String], + schema: Option[StructType]): Source = { + LastOptions.parameters = parameters + LastOptions.schema = schema + new Source { + override def getNextBatch(start: Option[Offset]): Option[Batch] = None + override def schema: StructType = StructType(StructField("a", IntegerType) :: Nil) + } + } + + override def createSink( + sqlContext: SQLContext, + parameters: Map[String, String], + partitionColumns: Seq[String]): Sink = { + LastOptions.parameters = parameters + LastOptions.partitionColumns = partitionColumns + new Sink { + override def addBatch(batch: Batch): Unit = {} + override def currentOffset: Option[Offset] = None + } + } +} + +class DataStreamReaderWriterSuite extends StreamTest with SharedSQLContext { + import testImplicits._ + + test("resolve default source") { + sqlContext.streamFrom + .format("org.apache.spark.sql.streaming.test") + .open() + .streamTo + .format("org.apache.spark.sql.streaming.test") + .start() + .stop() + } + + test("resolve full class") { + sqlContext.streamFrom + .format("org.apache.spark.sql.streaming.test.DefaultSource") + .open() + .streamTo + .format("org.apache.spark.sql.streaming.test") + .start() + .stop() + } + + test("options") { + val map = new java.util.HashMap[String, String] + map.put("opt3", "3") + + val df = sqlContext.streamFrom + .format("org.apache.spark.sql.streaming.test") + .option("opt1", "1") + .options(Map("opt2" -> "2")) + .options(map) + .open() + + assert(LastOptions.parameters("opt1") == "1") + assert(LastOptions.parameters("opt2") == "2") + assert(LastOptions.parameters("opt3") == "3") + + LastOptions.parameters = null + + df.streamTo + .format("org.apache.spark.sql.streaming.test") + .option("opt1", "1") + .options(Map("opt2" -> "2")) + .options(map) + .start() + .stop() + + assert(LastOptions.parameters("opt1") == "1") + assert(LastOptions.parameters("opt2") == "2") + assert(LastOptions.parameters("opt3") == "3") + } + + test("partitioning") { + val df = sqlContext.streamFrom + .format("org.apache.spark.sql.streaming.test") + .open() + + df.streamTo + .format("org.apache.spark.sql.streaming.test") + .start() + .stop() + assert(LastOptions.partitionColumns == Nil) + + df.streamTo + .format("org.apache.spark.sql.streaming.test") + .partitionBy("a") + .start() + .stop() + assert(LastOptions.partitionColumns == Seq("a")) + + + withSQLConf("spark.sql.caseSensitive" -> "false") { + df.streamTo + .format("org.apache.spark.sql.streaming.test") + .partitionBy("A") + .start() + .stop() + assert(LastOptions.partitionColumns == Seq("a")) + } + + intercept[AnalysisException] { + df.streamTo + .format("org.apache.spark.sql.streaming.test") + .partitionBy("b") + .start() + .stop() + } + } + + test("stream paths") { + val df = sqlContext.streamFrom + .format("org.apache.spark.sql.streaming.test") + .open("/test") + + assert(LastOptions.parameters("path") == "/test") + + LastOptions.parameters = null + + df.streamTo + .format("org.apache.spark.sql.streaming.test") + .start("/test") + .stop() + + assert(LastOptions.parameters("path") == "/test") + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala new file mode 100644 index 000000000000..81760d2aa820 --- /dev/null +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.StreamTest +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.test.SharedSQLContext + +class MemorySourceStressSuite extends StreamTest with SharedSQLContext { + import testImplicits._ + + test("memory stress test") { + val input = MemoryStream[Int] + val mapped = input.toDS().map(_ + 1) + + runStressTest(mapped, AddData(input, _: _*)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala new file mode 100644 index 000000000000..989465826d54 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.execution.streaming.{CompositeOffset, LongOffset, Offset} + +trait OffsetSuite extends SparkFunSuite { + /** Creates test to check all the comparisons of offsets given a `one` that is less than `two`. */ + def compare(one: Offset, two: Offset): Unit = { + test(s"comparision $one <=> $two") { + assert(one < two) + assert(one <= two) + assert(one <= one) + assert(two > one) + assert(two >= one) + assert(one >= one) + assert(one == one) + assert(two == two) + assert(one != two) + assert(two != one) + } + } + + /** Creates test to check that non-equality comparisons throw exception. 
*/ + def compareInvalid(one: Offset, two: Offset): Unit = { + test(s"invalid comparison $one <=> $two") { + intercept[IllegalArgumentException] { + assert(one < two) + } + + intercept[IllegalArgumentException] { + assert(one <= two) + } + + intercept[IllegalArgumentException] { + assert(one > two) + } + + intercept[IllegalArgumentException] { + assert(one >= two) + } + + assert(!(one == two)) + assert(!(two == one)) + assert(one != two) + assert(two != one) + } + } +} + +class LongOffsetSuite extends OffsetSuite { + val one = LongOffset(1) + val two = LongOffset(2) + compare(one, two) +} + +class CompositeOffsetSuite extends OffsetSuite { + compare( + one = CompositeOffset(Some(LongOffset(1)) :: Nil), + two = CompositeOffset(Some(LongOffset(2)) :: Nil)) + + compare( + one = CompositeOffset(None :: Nil), + two = CompositeOffset(Some(LongOffset(2)) :: Nil)) + + compareInvalid( // sizes must be same + one = CompositeOffset(Nil), + two = CompositeOffset(Some(LongOffset(2)) :: Nil)) + + compare( + one = CompositeOffset.fill(LongOffset(0), LongOffset(1)), + two = CompositeOffset.fill(LongOffset(1), LongOffset(2))) + + compare( + one = CompositeOffset.fill(LongOffset(1), LongOffset(1)), + two = CompositeOffset.fill(LongOffset(1), LongOffset(2))) + + compareInvalid( + one = CompositeOffset.fill(LongOffset(2), LongOffset(1)), // vector time inconsistent + two = CompositeOffset.fill(LongOffset(1), LongOffset(2))) +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala new file mode 100644 index 000000000000..fbb1792596b1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.{Row, StreamTest} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.test.SharedSQLContext + +class StreamSuite extends StreamTest with SharedSQLContext { + + import testImplicits._ + + test("map with recovery") { + val inputData = MemoryStream[Int] + val mapped = inputData.toDS().map(_ + 1) + + testStream(mapped)( + AddData(inputData, 1, 2, 3), + StartStream, + CheckAnswer(2, 3, 4), + StopStream, + AddData(inputData, 4, 5, 6), + StartStream, + CheckAnswer(2, 3, 4, 5, 6, 7)) + } + + test("join") { + // Make a table and ensure it will be broadcast. + val smallTable = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + + // Join the input stream with a table. 
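+    // Each micro-batch read from the stream is joined against the static DataFrame.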
+ val inputData = MemoryStream[Int] + val joined = inputData.toDS().toDF().join(smallTable, $"value" === $"number") + + testStream(joined)( + AddData(inputData, 1, 2, 3), + CheckAnswer(Row(1, 1, "one"), Row(2, 2, "two")), + AddData(inputData, 4), + CheckAnswer(Row(1, 1, "one"), Row(2, 2, "two"), Row(4, 4, "four"))) + } + + test("union two streams") { + val inputData1 = MemoryStream[Int] + val inputData2 = MemoryStream[Int] + + val unioned = inputData1.toDS().union(inputData2.toDS()) + + testStream(unioned)( + AddData(inputData1, 1, 3, 5), + CheckAnswer(1, 3, 5), + AddData(inputData2, 2, 4, 6), + CheckAnswer(1, 2, 3, 4, 5, 6), + StopStream, + AddData(inputData1, 7), + StartStream, + AddData(inputData2, 8), + CheckAnswer(1, 2, 3, 4, 5, 6, 7, 8)) + } + + test("sql queries") { + val inputData = MemoryStream[Int] + inputData.toDF().registerTempTable("stream") + val evens = sql("SELECT * FROM stream WHERE value % 2 = 0") + + testStream(evens)( + AddData(inputData, 1, 2, 3, 4), + CheckAnswer(2, 4)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index e7b376548787..c341191c70bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -36,7 +36,7 @@ trait SharedSQLContext extends SQLTestUtils { /** * The [[TestSQLContext]] to use for all tests in this suite. */ - protected def sqlContext: SQLContext = _ctx + protected implicit def sqlContext: SQLContext = _ctx /** * Initialize the [[TestSQLContext]]. From 29d92181d0c49988c387d34e4a71b1afe02c29e2 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 2 Feb 2016 10:15:40 -0800 Subject: [PATCH 095/131] [SPARK-13094][SQL] Add encoders for seq/array of primitives Author: Michael Armbrust Closes #11014 from marmbrus/seqEncoders. 
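For illustration, a minimal sketch of how the new implicit encoders are used once they are in scope (it assumes a Spark 1.6.1 `SQLContext` named `sqlContext`, and mirrors the checks added to `DatasetPrimitiveSuite` below):

    import sqlContext.implicits._

    // Encoder[Seq[Int]] and Encoder[Array[Double]] are now resolved implicitly
    // (newIntSeqEncoder / newDoubleArrayEncoder), so toDS() works on these collections.
    val seqDS = Seq(Seq(1, 2, 3), Seq(4, 5)).toDS()
    val arrDS = Seq(Array(1.0, 2.0), Array(3.0)).toDS()

    seqDS.map(_.sum).collect()   // Array(6, 9)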
--- .../org/apache/spark/sql/SQLImplicits.scala | 63 ++++++++++++++++++- .../spark/sql/DatasetPrimitiveSuite.scala | 22 +++++++ .../org/apache/spark/sql/QueryTest.scala | 8 ++- 3 files changed, 91 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index ab414799f1a4..16c4095db722 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -39,6 +39,8 @@ abstract class SQLImplicits { /** @since 1.6.0 */ implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder() + // Primitives + /** @since 1.6.0 */ implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder() @@ -56,13 +58,72 @@ abstract class SQLImplicits { /** @since 1.6.0 */ implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder() - /** @since 1.6.0 */ + /** @since 1.6.0 */ implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder() /** @since 1.6.0 */ implicit def newStringEncoder: Encoder[String] = ExpressionEncoder() + // Seqs + + /** @since 1.6.1 */ + implicit def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder() + + // Arrays + + /** @since 1.6.1 */ + implicit def newIntArrayEncoder: Encoder[Array[Int]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newLongArrayEncoder: Encoder[Array[Long]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newDoubleArrayEncoder: Encoder[Array[Double]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newFloatArrayEncoder: Encoder[Array[Float]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newByteArrayEncoder: Encoder[Array[Byte]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newShortArrayEncoder: Encoder[Array[Short]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newBooleanArrayEncoder: Encoder[Array[Boolean]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newStringArrayEncoder: Encoder[Array[String]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newProductArrayEncoder[A <: Product : TypeTag]: Encoder[Array[A]] = + ExpressionEncoder() + /** * Creates a [[Dataset]] from an RDD. 
* @since 1.6.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index f75d0961823c..243d13b19d6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -105,4 +105,26 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { agged, "1", "abc", "3", "xyz", "5", "hello") } + + test("Arrays and Lists") { + checkAnswer(Seq(Seq(1)).toDS(), Seq(1)) + checkAnswer(Seq(Seq(1.toLong)).toDS(), Seq(1.toLong)) + checkAnswer(Seq(Seq(1.toDouble)).toDS(), Seq(1.toDouble)) + checkAnswer(Seq(Seq(1.toFloat)).toDS(), Seq(1.toFloat)) + checkAnswer(Seq(Seq(1.toByte)).toDS(), Seq(1.toByte)) + checkAnswer(Seq(Seq(1.toShort)).toDS(), Seq(1.toShort)) + checkAnswer(Seq(Seq(true)).toDS(), Seq(true)) + checkAnswer(Seq(Seq("test")).toDS(), Seq("test")) + checkAnswer(Seq(Seq(Tuple1(1))).toDS(), Seq(Tuple1(1))) + + checkAnswer(Seq(Array(1)).toDS(), Array(1)) + checkAnswer(Seq(Array(1.toLong)).toDS(), Array(1.toLong)) + checkAnswer(Seq(Array(1.toDouble)).toDS(), Array(1.toDouble)) + checkAnswer(Seq(Array(1.toFloat)).toDS(), Array(1.toFloat)) + checkAnswer(Seq(Array(1.toByte)).toDS(), Array(1.toByte)) + checkAnswer(Seq(Array(1.toShort)).toDS(), Array(1.toShort)) + checkAnswer(Seq(Array(true)).toDS(), Array(true)) + checkAnswer(Seq(Array("test")).toDS(), Array("test")) + checkAnswer(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 405e5891ac97..5401212428d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -95,7 +95,13 @@ abstract class QueryTest extends PlanTest { """.stripMargin, e) } - if (decoded != expectedAnswer.toSet) { + // Handle the case where the return type is an array + val isArray = decoded.headOption.map(_.getClass.isArray).getOrElse(false) + def normalEquality = decoded == expectedAnswer.toSet + def expectedAsSeq = expectedAnswer.map(_.asInstanceOf[Array[_]].toSeq).toSet + def decodedAsSeq = decoded.map(_.asInstanceOf[Array[_]].toSeq) + + if (!((isArray && expectedAsSeq == decodedAsSeq) || normalEquality)) { val expected = expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted val actual = decoded.toSet.toSeq.map((a: Any) => a.toString).sorted From b93830126cc59a26e2cfb5d7b3c17f9cfbf85988 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 2 Feb 2016 10:41:06 -0800 Subject: [PATCH 096/131] [SPARK-13114][SQL] Add a test for tokens more than the fields in schema https://issues.apache.org/jira/browse/SPARK-13114 This PR adds a test for tokens more than the fields in schema. Author: hyukjinkwon Closes #11020 from HyukjinKwon/SPARK-13114. --- sql/core/src/test/resources/cars-malformed.csv | 6 ++++++ .../sql/execution/datasources/csv/CSVSuite.scala | 12 ++++++++++++ 2 files changed, 18 insertions(+) create mode 100644 sql/core/src/test/resources/cars-malformed.csv diff --git a/sql/core/src/test/resources/cars-malformed.csv b/sql/core/src/test/resources/cars-malformed.csv new file mode 100644 index 000000000000..cfa378c01f1d --- /dev/null +++ b/sql/core/src/test/resources/cars-malformed.csv @@ -0,0 +1,6 @@ +~ All the rows here are malformed having tokens more than the schema (header). 
+year,make,model,comment,blank +"2012","Tesla","S","No comment",,null,null + +1997,Ford,E350,"Go get one now they are going fast",,null,null +2015,Chevy,,,, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index a79566b1f365..fa4f137b703b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.types._ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { private val carsFile = "cars.csv" + private val carsMalformedFile = "cars-malformed.csv" private val carsFile8859 = "cars_iso-8859-1.csv" private val carsTsvFile = "cars.tsv" private val carsAltFile = "cars-alternative.csv" @@ -191,6 +192,17 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt")) } + test("test for tokens more than the fields in the schema") { + val cars = sqlContext + .read + .format("csv") + .option("header", "false") + .option("comment", "~") + .load(testFile(carsMalformedFile)) + + verifyCars(cars, withHeader = false, checkTypes = false) + } + test("test with null quote character") { val cars = sqlContext.read .format("csv") From cba1d6b659288bfcd8db83a6d778155bab2bbecf Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 2 Feb 2016 10:50:22 -0800 Subject: [PATCH 097/131] [SPARK-12631][PYSPARK][DOC] PySpark clustering parameter desc to consistent format Part of task for [SPARK-11219](https://issues.apache.org/jira/browse/SPARK-11219) to make PySpark MLlib parameter description formatting consistent. This is for the clustering module. Author: Bryan Cutler Closes #10610 from BryanCutler/param-desc-consistent-cluster-SPARK-12631. --- .../mllib/clustering/GaussianMixture.scala | 12 +- .../spark/mllib/clustering/KMeans.scala | 31 +- .../apache/spark/mllib/clustering/LDA.scala | 13 +- .../clustering/PowerIterationClustering.scala | 4 +- .../mllib/clustering/StreamingKMeans.scala | 6 +- python/pyspark/mllib/clustering.py | 265 +++++++++++++----- 6 files changed, 228 insertions(+), 103 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index 7b203e2f4081..88dbfe3fcc9f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -45,10 +45,10 @@ import org.apache.spark.util.Utils * This is due to high-dimensional data (a) making it difficult to cluster at all (based * on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions. * - * @param k The number of independent Gaussians in the mixture model - * @param convergenceTol The maximum change in log-likelihood at which convergence - * is considered to have occurred. - * @param maxIterations The maximum number of iterations to perform + * @param k Number of independent Gaussians in the mixture model. + * @param convergenceTol Maximum change in log-likelihood at which convergence + * is considered to have occurred. + * @param maxIterations Maximum number of iterations allowed. 
*/ @Since("1.3.0") class GaussianMixture private ( @@ -108,7 +108,7 @@ class GaussianMixture private ( def getK: Int = k /** - * Set the maximum number of iterations to run. Default: 100 + * Set the maximum number of iterations allowed. Default: 100 */ @Since("1.3.0") def setMaxIterations(maxIterations: Int): this.type = { @@ -117,7 +117,7 @@ class GaussianMixture private ( } /** - * Return the maximum number of iterations to run + * Return the maximum number of iterations allowed */ @Since("1.3.0") def getMaxIterations: Int = maxIterations diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index ca11ede4ccd4..901164a39117 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -70,13 +70,13 @@ class KMeans private ( } /** - * Maximum number of iterations to run. + * Maximum number of iterations allowed. */ @Since("1.4.0") def getMaxIterations: Int = maxIterations /** - * Set maximum number of iterations to run. Default: 20. + * Set maximum number of iterations allowed. Default: 20. */ @Since("0.8.0") def setMaxIterations(maxIterations: Int): this.type = { @@ -482,12 +482,15 @@ object KMeans { /** * Trains a k-means model using the given set of parameters. * - * @param data training points stored as `RDD[Vector]` - * @param k number of clusters - * @param maxIterations max number of iterations - * @param runs number of parallel runs, defaults to 1. The best model is returned. - * @param initializationMode initialization model, either "random" or "k-means||" (default). - * @param seed random seed value for cluster initialization + * @param data Training points as an `RDD` of `Vector` types. + * @param k Number of clusters to create. + * @param maxIterations Maximum number of iterations allowed. + * @param runs Number of runs to execute in parallel. The best model according to the cost + * function will be returned. (default: 1) + * @param initializationMode The initialization algorithm. This can either be "random" or + * "k-means||". (default: "k-means||") + * @param seed Random seed for cluster initialization. Default is to generate seed based + * on system time. */ @Since("1.3.0") def train( @@ -508,11 +511,13 @@ object KMeans { /** * Trains a k-means model using the given set of parameters. * - * @param data training points stored as `RDD[Vector]` - * @param k number of clusters - * @param maxIterations max number of iterations - * @param runs number of parallel runs, defaults to 1. The best model is returned. - * @param initializationMode initialization model, either "random" or "k-means||" (default). + * @param data Training points as an `RDD` of `Vector` types. + * @param k Number of clusters to create. + * @param maxIterations Maximum number of iterations allowed. + * @param runs Number of runs to execute in parallel. The best model according to the cost + * function will be returned. (default: 1) + * @param initializationMode The initialization algorithm. This can either be "random" or + * "k-means||". 
(default: "k-means||") */ @Since("0.8.0") def train( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index eb802a365ed6..81566b4779d6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -61,14 +61,13 @@ class LDA private ( ldaOptimizer = new EMLDAOptimizer) /** - * Number of topics to infer. I.e., the number of soft cluster centers. - * + * Number of topics to infer, i.e., the number of soft cluster centers. */ @Since("1.3.0") def getK: Int = k /** - * Number of topics to infer. I.e., the number of soft cluster centers. + * Set the number of topics to infer, i.e., the number of soft cluster centers. * (default = 10) */ @Since("1.3.0") @@ -222,13 +221,13 @@ class LDA private ( def setBeta(beta: Double): this.type = setTopicConcentration(beta) /** - * Maximum number of iterations for learning. + * Maximum number of iterations allowed. */ @Since("1.3.0") def getMaxIterations: Int = maxIterations /** - * Maximum number of iterations for learning. + * Set the maximum number of iterations allowed. * (default = 20) */ @Since("1.3.0") @@ -238,13 +237,13 @@ class LDA private ( } /** - * Random seed + * Random seed for cluster initialization. */ @Since("1.3.0") def getSeed: Long = seed /** - * Random seed + * Set the random seed for cluster initialization. */ @Since("1.3.0") def setSeed(seed: Long): this.type = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 2ab0920b0636..1ab7cb393b08 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -111,7 +111,9 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode * * @param k Number of clusters. * @param maxIterations Maximum number of iterations of the PIC algorithm. - * @param initMode Initialization mode. + * @param initMode Set the initialization mode. This can be either "random" to use a random vector + * as vertex properties, or "degree" to use normalized sum similarities. + * Default: random. * * @see [[http://en.wikipedia.org/wiki/Spectral_clustering Spectral clustering (Wikipedia)]] */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 79d217e183c6..d99b89dc49eb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -183,7 +183,7 @@ class StreamingKMeans @Since("1.2.0") ( } /** - * Set the decay factor directly (for forgetful algorithms). + * Set the forgetfulness of the previous centroids. */ @Since("1.2.0") def setDecayFactor(a: Double): this.type = { @@ -192,7 +192,9 @@ class StreamingKMeans @Since("1.2.0") ( } /** - * Set the half life and time unit ("batches" or "points") for forgetful algorithms. + * Set the half life and time unit ("batches" or "points"). If points, then the decay factor + * is raised to the power of number of new points and if batches, then decay factor will be + * used as is. 
*/ @Since("1.2.0") def setHalfLife(halfLife: Double, timeUnit: String): this.type = { diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 4e9eb96fd9da..ad04e46e8870 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -88,8 +88,11 @@ def predict(self, x): Find the cluster that each of the points belongs to in this model. - :param x: the point (or RDD of points) to determine - compute the clusters for. + :param x: + A data point (or RDD of points) to determine cluster index. + :return: + Predicted cluster index or an RDD of predicted cluster indices + if the input is an RDD. """ if isinstance(x, RDD): vecs = x.map(_convert_to_vector) @@ -105,7 +108,8 @@ def computeCost(self, x): points to their nearest center) for this model on the given data. If provided with an RDD of points returns the sum. - :param point: the point or RDD of points to compute the cost(s). + :param point: + A data point (or RDD of points) to compute the cost(s). """ if isinstance(x, RDD): vecs = x.map(_convert_to_vector) @@ -143,17 +147,23 @@ def train(self, rdd, k=4, maxIterations=20, minDivisibleClusterSize=1.0, seed=-1 """ Runs the bisecting k-means algorithm return the model. - :param rdd: input RDD to be trained on - :param k: The desired number of leaf clusters (default: 4). - The actual number could be smaller if there are no divisible - leaf clusters. - :param maxIterations: the max number of k-means iterations to - split clusters (default: 20) - :param minDivisibleClusterSize: the minimum number of points - (if >= 1.0) or the minimum proportion of points (if < 1.0) - of a divisible cluster (default: 1) - :param seed: a random seed (default: -1888008604 from - classOf[BisectingKMeans].getName.##) + :param rdd: + Training points as an `RDD` of `Vector` or convertible + sequence types. + :param k: + The desired number of leaf clusters. The actual number could + be smaller if there are no divisible leaf clusters. + (default: 4) + :param maxIterations: + Maximum number of iterations allowed to split clusters. + (default: 20) + :param minDivisibleClusterSize: + Minimum number of points (if >= 1.0) or the minimum proportion + of points (if < 1.0) of a divisible cluster. + (default: 1) + :param seed: + Random seed value for cluster initialization. + (default: -1888008604 from classOf[BisectingKMeans].getName.##) """ java_model = callMLlibFunc( "trainBisectingKMeans", rdd.map(_convert_to_vector), @@ -239,8 +249,11 @@ def predict(self, x): Find the cluster that each of the points belongs to in this model. - :param x: the point (or RDD of points) to determine - compute the clusters for. + :param x: + A data point (or RDD of points) to determine cluster index. + :return: + Predicted cluster index or an RDD of predicted cluster indices + if the input is an RDD. """ best = 0 best_distance = float("inf") @@ -262,7 +275,8 @@ def computeCost(self, rdd): their nearest center) for this model on the given data. - :param point: the RDD of points to compute the cost on. + :param rdd: + The RDD of points to compute the cost on. """ cost = callMLlibFunc("computeCostKmeansModel", rdd.map(_convert_to_vector), [_convert_to_vector(c) for c in self.centers]) @@ -296,7 +310,44 @@ class KMeans(object): @since('0.9.0') def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||", seed=None, initializationSteps=5, epsilon=1e-4, initialModel=None): - """Train a k-means clustering model.""" + """ + Train a k-means clustering model. 
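(The parameter list for `train` continues below.) As a minimal usage sketch of the k-means API documented in this file, assuming an active SparkContext `sc` and purely illustrative data:

    from pyspark.mllib.clustering import KMeans

    data = sc.parallelize([[0.0, 0.0], [1.0, 1.0], [9.0, 8.0], [8.0, 9.0]])
    model = KMeans.train(data, k=2, maxIterations=10, initializationMode="random")
    model.predict([0.0, 0.0])   # cluster index for a single point
    model.computeCost(data)     # sum of squared distances to the nearest centers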
+ + :param rdd: + Training points as an `RDD` of `Vector` or convertible + sequence types. + :param k: + Number of clusters to create. + :param maxIterations: + Maximum number of iterations allowed. + (default: 100) + :param runs: + Number of runs to execute in parallel. The best model according + to the cost function will be returned (deprecated in 1.6.0). + (default: 1) + :param initializationMode: + The initialization algorithm. This can be either "random" or + "k-means||". + (default: "k-means||") + :param seed: + Random seed value for cluster initialization. Set as None to + generate seed based on system time. + (default: None) + :param initializationSteps: + Number of steps for the k-means|| initialization mode. + This is an advanced setting -- the default of 5 is almost + always enough. + (default: 5) + :param epsilon: + Distance threshold within which a center will be considered to + have converged. If all centers move less than this Euclidean + distance, iterations are stopped. + (default: 1e-4) + :param initialModel: + Initial cluster centers can be provided as a KMeansModel object + rather than using the random or k-means|| initializationModel. + (default: None) + """ if runs != 1: warnings.warn( "Support for runs is deprecated in 1.6.0. This param will have no effect in 2.0.0.") @@ -415,8 +466,11 @@ def predict(self, x): Find the cluster to which the point 'x' or each point in RDD 'x' has maximum membership in this model. - :param x: vector or RDD of vector represents data points. - :return: cluster label or RDD of cluster labels. + :param x: + A feature vector or an RDD of vectors representing data points. + :return: + Predicted cluster label or an RDD of predicted cluster labels + if the input is an RDD. """ if isinstance(x, RDD): cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z))) @@ -430,9 +484,11 @@ def predictSoft(self, x): """ Find the membership of point 'x' or each point in RDD 'x' to all mixture components. - :param x: vector or RDD of vector represents data points. - :return: the membership value to all mixture components for vector 'x' - or each vector in RDD 'x'. + :param x: + A feature vector or an RDD of vectors representing data points. + :return: + The membership value to all mixture components for vector 'x' + or each vector in RDD 'x'. """ if isinstance(x, RDD): means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians]) @@ -447,8 +503,10 @@ def predictSoft(self, x): def load(cls, sc, path): """Load the GaussianMixtureModel from disk. - :param sc: SparkContext - :param path: str, path to where the model is stored. + :param sc: + SparkContext. + :param path: + Path to where the model is stored. """ model = cls._load_java(sc, path) wrapper = sc._jvm.GaussianMixtureModelWrapper(model) @@ -461,19 +519,35 @@ class GaussianMixture(object): Learning algorithm for Gaussian Mixtures using the expectation-maximization algorithm. - :param data: RDD of data points - :param k: Number of components - :param convergenceTol: Threshold value to check the convergence criteria. Defaults to 1e-3 - :param maxIterations: Number of iterations. Default to 100 - :param seed: Random Seed - :param initialModel: GaussianMixtureModel for initializing learning - .. versionadded:: 1.3.0 """ @classmethod @since('1.3.0') def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initialModel=None): - """Train a Gaussian Mixture clustering model.""" + """ + Train a Gaussian Mixture clustering model. 
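(The Gaussian mixture parameter list continues below.) As with k-means above, a small sketch of how the documented `GaussianMixture.train`, `predict` and `predictSoft` calls fit together, again assuming a SparkContext `sc` and illustrative data:

    from pyspark.mllib.clustering import GaussianMixture

    data = sc.parallelize([[-0.1], [-0.05], [4.9], [5.1]])
    gmm = GaussianMixture.train(data, k=2, convergenceTol=1e-3, maxIterations=100, seed=10)
    gmm.predict([5.0])       # index of the most likely component
    gmm.predictSoft([5.0])   # membership values for all k components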
+ + :param rdd: + Training points as an `RDD` of `Vector` or convertible + sequence types. + :param k: + Number of independent Gaussians in the mixture model. + :param convergenceTol: + Maximum change in log-likelihood at which convergence is + considered to have occurred. + (default: 1e-3) + :param maxIterations: + Maximum number of iterations allowed. + (default: 100) + :param seed: + Random seed for initial Gaussian distribution. Set as None to + generate seed based on system time. + (default: None) + :param initialModel: + Initial GMM starting point, bypassing the random + initialization. + (default: None) + """ initialModelWeights = None initialModelMu = None initialModelSigma = None @@ -574,18 +648,24 @@ class PowerIterationClustering(object): @since('1.5.0') def train(cls, rdd, k, maxIterations=100, initMode="random"): """ - :param rdd: an RDD of (i, j, s,,ij,,) tuples representing the - affinity matrix, which is the matrix A in the PIC paper. - The similarity s,,ij,, must be nonnegative. - This is a symmetric matrix and hence s,,ij,, = s,,ji,,. - For any (i, j) with nonzero similarity, there should be - either (i, j, s,,ij,,) or (j, i, s,,ji,,) in the input. - Tuples with i = j are ignored, because we assume - s,,ij,, = 0.0. - :param k: Number of clusters. - :param maxIterations: Maximum number of iterations of the - PIC algorithm. - :param initMode: Initialization mode. + :param rdd: + An RDD of (i, j, s\ :sub:`ij`\) tuples representing the + affinity matrix, which is the matrix A in the PIC paper. The + similarity s\ :sub:`ij`\ must be nonnegative. This is a symmetric + matrix and hence s\ :sub:`ij`\ = s\ :sub:`ji`\ For any (i, j) with + nonzero similarity, there should be either (i, j, s\ :sub:`ij`\) or + (j, i, s\ :sub:`ji`\) in the input. Tuples with i = j are ignored, + because it is assumed s\ :sub:`ij`\ = 0.0. + :param k: + Number of clusters. + :param maxIterations: + Maximum number of iterations of the PIC algorithm. + (default: 100) + :param initMode: + Initialization mode. This can be either "random" to use + a random vector as vertex properties, or "degree" to use + normalized sum similarities. + (default: "random") """ model = callMLlibFunc("trainPowerIterationClusteringModel", rdd.map(_convert_to_vector), int(k), int(maxIterations), initMode) @@ -625,8 +705,10 @@ class StreamingKMeansModel(KMeansModel): and new data. If it set to zero, the old centroids are completely forgotten. - :param clusterCenters: Initial cluster centers. - :param clusterWeights: List of weights assigned to each cluster. + :param clusterCenters: + Initial cluster centers. + :param clusterWeights: + List of weights assigned to each cluster. >>> initCenters = [[0.0, 0.0], [1.0, 1.0]] >>> initWeights = [1.0, 1.0] @@ -673,11 +755,14 @@ def clusterWeights(self): def update(self, data, decayFactor, timeUnit): """Update the centroids, according to data - :param data: Should be a RDD that represents the new data. - :param decayFactor: forgetfulness of the previous centroids. - :param timeUnit: Can be "batches" or "points". If points, then the - decay factor is raised to the power of number of new - points and if batches, it is used as it is. + :param data: + RDD with new data for the model update. + :param decayFactor: + Forgetfulness of the previous centroids. + :param timeUnit: + Can be "batches" or "points". If points, then the decay factor + is raised to the power of number of new points and if batches, + then decay factor will be used as is. 
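A sketch of the discount implied by the description above (names are illustrative; this is not the model's internal code):

    def discount(decay_factor, time_unit, num_new_points):
        # "points": the decay factor is raised to the power of the number of new points;
        # "batches": the decay factor is used as is.
        if time_unit == "points":
            return decay_factor ** num_new_points
        return decay_factor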
""" if not isinstance(data, RDD): raise TypeError("Data should be of an RDD, got %s." % type(data)) @@ -704,10 +789,17 @@ class StreamingKMeans(object): More details on how the centroids are updated are provided under the docs of StreamingKMeansModel. - :param k: int, number of clusters - :param decayFactor: float, forgetfulness of the previous centroids. - :param timeUnit: can be "batches" or "points". If points, then the - decayfactor is raised to the power of no. of new points. + :param k: + Number of clusters. + (default: 2) + :param decayFactor: + Forgetfulness of the previous centroids. + (default: 1.0) + :param timeUnit: + Can be "batches" or "points". If points, then the decay factor is + raised to the power of number of new points and if batches, then + decay factor will be used as is. + (default: "batches") .. versionadded:: 1.5.0 """ @@ -870,11 +962,13 @@ def describeTopics(self, maxTermsPerTopic=None): WARNING: If vocabSize and k are large, this can return a large object! - :param maxTermsPerTopic: Maximum number of terms to collect for each topic. - (default: vocabulary size) - :return: Array over topics. Each topic is represented as a pair of matching arrays: - (term indices, term weights in topic). - Each topic's terms are sorted in order of decreasing weight. + :param maxTermsPerTopic: + Maximum number of terms to collect for each topic. + (default: vocabulary size) + :return: + Array over topics. Each topic is represented as a pair of + matching arrays: (term indices, term weights in topic). + Each topic's terms are sorted in order of decreasing weight. """ if maxTermsPerTopic is None: topics = self.call("describeTopics") @@ -887,8 +981,10 @@ def describeTopics(self, maxTermsPerTopic=None): def load(cls, sc, path): """Load the LDAModel from disk. - :param sc: SparkContext - :param path: str, path to where the model is stored. + :param sc: + SparkContext. + :param path: + Path to where the model is stored. """ if not isinstance(sc, SparkContext): raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) @@ -909,17 +1005,38 @@ def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0, topicConcentration=-1.0, seed=None, checkpointInterval=10, optimizer="em"): """Train a LDA model. - :param rdd: RDD of data points - :param k: Number of clusters you want - :param maxIterations: Number of iterations. Default to 20 - :param docConcentration: Concentration parameter (commonly named "alpha") - for the prior placed on documents' distributions over topics ("theta"). - :param topicConcentration: Concentration parameter (commonly named "beta" or "eta") - for the prior placed on topics' distributions over terms. - :param seed: Random Seed - :param checkpointInterval: Period (in iterations) between checkpoints. - :param optimizer: LDAOptimizer used to perform the actual calculation. - Currently "em", "online" are supported. Default to "em". + :param rdd: + RDD of documents, which are tuples of document IDs and term + (word) count vectors. The term count vectors are "bags of + words" with a fixed-size vocabulary (where the vocabulary size + is the length of the vector). Document IDs must be unique + and >= 0. + :param k: + Number of topics to infer, i.e., the number of soft cluster + centers. + (default: 10) + :param maxIterations: + Maximum number of iterations allowed. + (default: 20) + :param docConcentration: + Concentration parameter (commonly named "alpha") for the prior + placed on documents' distributions over topics ("theta"). 
+ (default: -1.0) + :param topicConcentration: + Concentration parameter (commonly named "beta" or "eta") for + the prior placed on topics' distributions over terms. + (default: -1.0) + :param seed: + Random seed for cluster initialization. Set as None to generate + seed based on system time. + (default: None) + :param checkpointInterval: + Period (in iterations) between checkpoints. + (default: 10) + :param optimizer: + LDAOptimizer used to perform the actual calculation. Currently + "em", "online" are supported. + (default: "em") """ model = callMLlibFunc("trainLDAModel", rdd, k, maxIterations, docConcentration, topicConcentration, seed, From 358300c795025735c3b2f96c5447b1b227d4abc1 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Tue, 2 Feb 2016 11:09:40 -0800 Subject: [PATCH 098/131] [SPARK-13056][SQL] map column would throw NPE if value is null Jira: https://issues.apache.org/jira/browse/SPARK-13056 Create a map like { "a": "somestring", "b": null} Query like SELECT col["b"] FROM t1; NPE would be thrown. Author: Daoyuan Wang Closes #10964 from adrian-wang/npewriter. --- .../expressions/complexTypeExtractors.scala | 15 +++++++++------ .../org/apache/spark/sql/SQLQuerySuite.scala | 10 ++++++++++ 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 5256baaf432a..9f2f82d68cca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -218,7 +218,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) protected override def nullSafeEval(value: Any, ordinal: Any): Any = { val baseValue = value.asInstanceOf[ArrayData] val index = ordinal.asInstanceOf[Number].intValue() - if (index >= baseValue.numElements() || index < 0) { + if (index >= baseValue.numElements() || index < 0 || baseValue.isNullAt(index)) { null } else { baseValue.get(index, dataType) @@ -267,6 +267,7 @@ case class GetMapValue(child: Expression, key: Expression) val map = value.asInstanceOf[MapData] val length = map.numElements() val keys = map.keyArray() + val values = map.valueArray() var i = 0 var found = false @@ -278,10 +279,10 @@ case class GetMapValue(child: Expression, key: Expression) } } - if (!found) { + if (!found || values.isNullAt(i)) { null } else { - map.valueArray().get(i, dataType) + values.get(i, dataType) } } @@ -291,10 +292,12 @@ case class GetMapValue(child: Expression, key: Expression) val keys = ctx.freshName("keys") val found = ctx.freshName("found") val key = ctx.freshName("key") + val values = ctx.freshName("values") nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" final int $length = $eval1.numElements(); final ArrayData $keys = $eval1.keyArray(); + final ArrayData $values = $eval1.valueArray(); int $index = 0; boolean $found = false; @@ -307,10 +310,10 @@ case class GetMapValue(child: Expression, key: Expression) } } - if ($found) { - ${ev.value} = ${ctx.getValue(eval1 + ".valueArray()", dataType, index)}; - } else { + if (!$found || $values.isNullAt($index)) { ${ev.isNull} = true; + } else { + ${ev.value} = ${ctx.getValue(values, dataType, index)}; } """ }) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 
2b821c1056f5..79bfd4b44b70 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2055,6 +2055,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) } + test("SPARK-13056: Null in map value causes NPE") { + val df = Seq(1 -> Map("abc" -> "somestring", "cba" -> null)).toDF("key", "value") + withTempTable("maptest") { + df.registerTempTable("maptest") + // local optimization will by pass codegen code, so we should keep the filter `key=1` + checkAnswer(sql("SELECT value['abc'] FROM maptest where key = 1"), Row("somestring")) + checkAnswer(sql("SELECT value['cba'] FROM maptest where key = 1"), Row(null)) + } + } + test("hash function") { val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") withTempTable("tbl") { From b1835d727234fdff42aa8cadd17ddcf43b0bed15 Mon Sep 17 00:00:00 2001 From: Grzegorz Chilkiewicz Date: Tue, 2 Feb 2016 11:16:24 -0800 Subject: [PATCH 099/131] [SPARK-12711][ML] ML StopWordsRemover does not protect itself from column name duplication Fixes problem and verifies fix by test suite. Also - adds optional parameter: nullable (Boolean) to: SchemaUtils.appendColumn and deduplicates SchemaUtils.appendColumn functions. Author: Grzegorz Chilkiewicz Closes #10741 from grzegorz-chilkiewicz/master. --- .../spark/ml/feature/StopWordsRemover.scala | 4 +--- .../org/apache/spark/ml/util/SchemaUtils.scala | 8 +++----- .../spark/ml/feature/StopWordsRemoverSuite.scala | 15 +++++++++++++++ 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index b93c9ed382bd..e53ef300f644 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -149,9 +149,7 @@ class StopWordsRemover(override val uid: String) val inputType = schema($(inputCol)).dataType require(inputType.sameType(ArrayType(StringType)), s"Input type must be ArrayType(StringType) but got $inputType.") - val outputFields = schema.fields :+ - StructField($(outputCol), inputType, schema($(inputCol)).nullable) - StructType(outputFields) + SchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable) } override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index e71dd9eee03e..76021ad8f4e6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -71,12 +71,10 @@ private[spark] object SchemaUtils { def appendColumn( schema: StructType, colName: String, - dataType: DataType): StructType = { + dataType: DataType, + nullable: Boolean = false): StructType = { if (colName.isEmpty) return schema - val fieldNames = schema.fieldNames - require(!fieldNames.contains(colName), s"Column $colName already exists.") - val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false) - StructType(outputFields) + appendColumn(schema, StructField(colName, dataType, nullable)) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index fb217e0c1de9..a5b24c18565b 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -89,4 +89,19 @@ class StopWordsRemoverSuite .setCaseSensitive(true) testDefaultReadWrite(t) } + + test("StopWordsRemover output column already exists") { + val outputCol = "expected" + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol(outputCol) + val dataSet = sqlContext.createDataFrame(Seq( + (Seq("The", "the", "swift"), Seq("swift")) + )).toDF("raw", outputCol) + + val thrown = intercept[IllegalArgumentException] { + testStopWordsRemover(remover, dataSet) + } + assert(thrown.getMessage == s"requirement failed: Column $outputCol already exists.") + } } From 7f6e3ec79b77400f558ceffa10b2af011962115f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 2 Feb 2016 11:29:20 -0800 Subject: [PATCH 100/131] [SPARK-13138][SQL] Add "logical" package prefix for ddl.scala ddl.scala is defined in the execution package, and yet its reference of "UnaryNode" and "Command" are logical. This was fairly confusing when I was trying to understand the ddl code. Author: Reynold Xin Closes #11021 from rxin/SPARK-13138. --- .../spark/sql/execution/datasources/ddl.scala | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 1554209be989..a141b58d3d72 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.types._ @@ -32,7 +33,7 @@ import org.apache.spark.sql.types._ */ case class DescribeCommand( table: LogicalPlan, - isExtended: Boolean) extends LogicalPlan with Command { + isExtended: Boolean) extends LogicalPlan with logical.Command { override def children: Seq[LogicalPlan] = Seq.empty @@ -59,7 +60,7 @@ case class CreateTableUsing( temporary: Boolean, options: Map[String, String], allowExisting: Boolean, - managedIfNoPath: Boolean) extends LogicalPlan with Command { + managedIfNoPath: Boolean) extends LogicalPlan with logical.Command { override def output: Seq[Attribute] = Seq.empty override def children: Seq[LogicalPlan] = Seq.empty @@ -67,8 +68,8 @@ case class CreateTableUsing( /** * A node used to support CTAS statements and saveAsTable for the data source API. - * This node is a [[UnaryNode]] instead of a [[Command]] because we want the analyzer - * can analyze the logical plan that will be used to populate the table. + * This node is a [[logical.UnaryNode]] instead of a [[logical.Command]] because we want the + * analyzer can analyze the logical plan that will be used to populate the table. * So, [[PreWriteCheck]] can detect cases that are not allowed. 
*/ case class CreateTableUsingAsSelect( @@ -79,7 +80,7 @@ case class CreateTableUsingAsSelect( bucketSpec: Option[BucketSpec], mode: SaveMode, options: Map[String, String], - child: LogicalPlan) extends UnaryNode { + child: LogicalPlan) extends logical.UnaryNode { override def output: Seq[Attribute] = Seq.empty[Attribute] } From be5dd881f1eff248224a92d57cfd1309cb3acf38 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 2 Feb 2016 11:50:14 -0800 Subject: [PATCH 101/131] [SPARK-12913] [SQL] Improve performance of stat functions As benchmarked and discussed here: https://github.com/apache/spark/pull/10786/files#r50038294, benefits from codegen, the declarative aggregate function could be much faster than imperative one. Author: Davies Liu Closes #10960 from davies/stddev. --- .../catalyst/analysis/HiveTypeCoercion.scala | 18 +- .../aggregate/CentralMomentAgg.scala | 285 ++++++++---------- .../catalyst/expressions/aggregate/Corr.scala | 208 ++++--------- .../expressions/aggregate/Covariance.scala | 205 ++++--------- .../expressions/aggregate/Kurtosis.scala | 54 ---- .../expressions/aggregate/Skewness.scala | 53 ---- .../expressions/aggregate/Stddev.scala | 81 ----- .../expressions/aggregate/Variance.scala | 81 ----- .../spark/sql/catalyst/expressions/misc.scala | 18 ++ .../apache/spark/sql/execution/Window.scala | 6 +- .../aggregate/TungstenAggregate.scala | 1 - .../BenchmarkWholeStageCodegen.scala | 55 +++- .../execution/HiveCompatibilitySuite.scala | 4 +- .../execution/AggregationQuerySuite.scala | 17 +- 14 files changed, 331 insertions(+), 755 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 957ac89fa530..57bdb164e1a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -347,18 +347,12 @@ object HiveTypeCoercion { case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) - case StddevPop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => - StddevPop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) - case StddevSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => - StddevSamp(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) - case VariancePop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => - VariancePop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) - case VarianceSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => - VarianceSamp(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) - case Skewness(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => - Skewness(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) - case Kurtosis(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => - 
Kurtosis(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) + case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) + case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType)) + case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType)) + case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType)) + case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType)) + case Kurtosis(e @ StringType()) => Kurtosis(Cast(e, DoubleType)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index 30f602227b17..9d2db4514481 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -17,10 +17,8 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** @@ -44,7 +42,7 @@ import org.apache.spark.sql.types._ * * @param child to compute central moments of. */ -abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate with Serializable { +abstract class CentralMomentAgg(child: Expression) extends DeclarativeAggregate { /** * The central moment order to be computed. @@ -52,178 +50,161 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w protected def momentOrder: Int override def children: Seq[Expression] = Seq(child) - override def nullable: Boolean = true - override def dataType: DataType = DoubleType + override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType) - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + protected val n = AttributeReference("n", DoubleType, nullable = false)() + protected val avg = AttributeReference("avg", DoubleType, nullable = false)() + protected val m2 = AttributeReference("m2", DoubleType, nullable = false)() + protected val m3 = AttributeReference("m3", DoubleType, nullable = false)() + protected val m4 = AttributeReference("m4", DoubleType, nullable = false)() - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, s"function $prettyName") + private def trimHigherOrder[T](expressions: Seq[T]) = expressions.take(momentOrder + 1) - override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + override val aggBufferAttributes = trimHigherOrder(Seq(n, avg, m2, m3, m4)) - /** - * Size of aggregation buffer. 
- */ - private[this] val bufferSize = 5 + override val initialValues: Seq[Expression] = Array.fill(momentOrder + 1)(Literal(0.0)) - override val aggBufferAttributes: Seq[AttributeReference] = Seq.tabulate(bufferSize) { i => - AttributeReference(s"M$i", DoubleType)() + override val updateExpressions: Seq[Expression] = { + val newN = n + Literal(1.0) + val delta = child - avg + val deltaN = delta / newN + val newAvg = avg + deltaN + val newM2 = m2 + delta * (delta - deltaN) + + val delta2 = delta * delta + val deltaN2 = deltaN * deltaN + val newM3 = if (momentOrder >= 3) { + m3 - Literal(3.0) * deltaN * newM2 + delta * (delta2 - deltaN2) + } else { + Literal(0.0) + } + val newM4 = if (momentOrder >= 4) { + m4 - Literal(4.0) * deltaN * newM3 - Literal(6.0) * deltaN2 * newM2 + + delta * (delta * delta2 - deltaN * deltaN2) + } else { + Literal(0.0) + } + + trimHigherOrder(Seq( + If(IsNull(child), n, newN), + If(IsNull(child), avg, newAvg), + If(IsNull(child), m2, newM2), + If(IsNull(child), m3, newM3), + If(IsNull(child), m4, newM4) + )) } - // Note: although this simply copies aggBufferAttributes, this common code can not be placed - // in the superclass because that will lead to initialization ordering issues. - override val inputAggBufferAttributes: Seq[AttributeReference] = - aggBufferAttributes.map(_.newInstance()) - - // buffer offsets - private[this] val nOffset = mutableAggBufferOffset - private[this] val meanOffset = mutableAggBufferOffset + 1 - private[this] val secondMomentOffset = mutableAggBufferOffset + 2 - private[this] val thirdMomentOffset = mutableAggBufferOffset + 3 - private[this] val fourthMomentOffset = mutableAggBufferOffset + 4 - - // frequently used values for online updates - private[this] var delta = 0.0 - private[this] var deltaN = 0.0 - private[this] var delta2 = 0.0 - private[this] var deltaN2 = 0.0 - private[this] var n = 0.0 - private[this] var mean = 0.0 - private[this] var m2 = 0.0 - private[this] var m3 = 0.0 - private[this] var m4 = 0.0 + override val mergeExpressions: Seq[Expression] = { - /** - * Initialize all moments to zero. - */ - override def initialize(buffer: MutableRow): Unit = { - for (aggIndex <- 0 until bufferSize) { - buffer.setDouble(mutableAggBufferOffset + aggIndex, 0.0) + val n1 = n.left + val n2 = n.right + val newN = n1 + n2 + val delta = avg.right - avg.left + val deltaN = If(newN === Literal(0.0), Literal(0.0), delta / newN) + val newAvg = avg.left + deltaN * n2 + + // higher order moments computed according to: + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics + val newM2 = m2.left + m2.right + delta * deltaN * n1 * n2 + // `m3.right` is not available if momentOrder < 3 + val newM3 = if (momentOrder >= 3) { + m3.left + m3.right + deltaN * deltaN * delta * n1 * n2 * (n1 - n2) + + Literal(3.0) * deltaN * (n1 * m2.right - n2 * m2.left) + } else { + Literal(0.0) } + // `m4.right` is not available if momentOrder < 4 + val newM4 = if (momentOrder >= 4) { + m4.left + m4.right + + deltaN * deltaN * deltaN * delta * n1 * n2 * (n1 * n1 - n1 * n2 + n2 * n2) + + Literal(6.0) * deltaN * deltaN * (n1 * n1 * m2.right + n2 * n2 * m2.left) + + Literal(4.0) * deltaN * (n1 * m3.right - n2 * m3.left) + } else { + Literal(0.0) + } + + trimHigherOrder(Seq(newN, newAvg, newM2, newM3, newM4)) } +} - /** - * Update the central moments buffer. 
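In equation form, the update and merge expressions above are the standard numerically stable recurrences for central moments (restating the second-order case from the reference linked in the code; the third and fourth orders follow the same pattern). Updating with a new value x:

    n' = n + 1,\qquad \delta = x - \bar{x},\qquad \bar{x}' = \bar{x} + \delta/n',\qquad M_2' = M_2 + \delta\,(x - \bar{x}')

and merging two partial aggregates A and B:

    n = n_A + n_B,\qquad \delta = \bar{x}_B - \bar{x}_A,\qquad \bar{x} = \bar{x}_A + \delta\,\frac{n_B}{n},\qquad M_2 = M_{2,A} + M_{2,B} + \delta^2\,\frac{n_A n_B}{n}

so that, for example, var_pop evaluates to M_2 / n and stddev_samp to \sqrt{M_2 / (n - 1)}, matching the evaluate expressions below.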
- */ - override def update(buffer: MutableRow, input: InternalRow): Unit = { - val v = Cast(child, DoubleType).eval(input) - if (v != null) { - val updateValue = v match { - case d: Double => d - } - - n = buffer.getDouble(nOffset) - mean = buffer.getDouble(meanOffset) - - n += 1.0 - buffer.setDouble(nOffset, n) - delta = updateValue - mean - deltaN = delta / n - mean += deltaN - buffer.setDouble(meanOffset, mean) - - if (momentOrder >= 2) { - m2 = buffer.getDouble(secondMomentOffset) - m2 += delta * (delta - deltaN) - buffer.setDouble(secondMomentOffset, m2) - } - - if (momentOrder >= 3) { - delta2 = delta * delta - deltaN2 = deltaN * deltaN - m3 = buffer.getDouble(thirdMomentOffset) - m3 += -3.0 * deltaN * m2 + delta * (delta2 - deltaN2) - buffer.setDouble(thirdMomentOffset, m3) - } - - if (momentOrder >= 4) { - m4 = buffer.getDouble(fourthMomentOffset) - m4 += -4.0 * deltaN * m3 - 6.0 * deltaN2 * m2 + - delta * (delta * delta2 - deltaN * deltaN2) - buffer.setDouble(fourthMomentOffset, m4) - } - } +// Compute the population standard deviation of a column +case class StddevPop(child: Expression) extends CentralMomentAgg(child) { + + override protected def momentOrder = 2 + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + Sqrt(m2 / n)) } - /** - * Merge two central moment buffers. - */ - override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - val n1 = buffer1.getDouble(nOffset) - val n2 = buffer2.getDouble(inputAggBufferOffset) - val mean1 = buffer1.getDouble(meanOffset) - val mean2 = buffer2.getDouble(inputAggBufferOffset + 1) + override def prettyName: String = "stddev_pop" +} + +// Compute the sample standard deviation of a column +case class StddevSamp(child: Expression) extends CentralMomentAgg(child) { + + override protected def momentOrder = 2 - var secondMoment1 = 0.0 - var secondMoment2 = 0.0 + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + If(n === Literal(1.0), Literal(Double.NaN), + Sqrt(m2 / (n - Literal(1.0))))) + } - var thirdMoment1 = 0.0 - var thirdMoment2 = 0.0 + override def prettyName: String = "stddev_samp" +} - var fourthMoment1 = 0.0 - var fourthMoment2 = 0.0 +// Compute the population variance of a column +case class VariancePop(child: Expression) extends CentralMomentAgg(child) { - n = n1 + n2 - buffer1.setDouble(nOffset, n) - delta = mean2 - mean1 - deltaN = if (n == 0.0) 0.0 else delta / n - mean = mean1 + deltaN * n2 - buffer1.setDouble(mutableAggBufferOffset + 1, mean) + override protected def momentOrder = 2 - // higher order moments computed according to: - // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics - if (momentOrder >= 2) { - secondMoment1 = buffer1.getDouble(secondMomentOffset) - secondMoment2 = buffer2.getDouble(inputAggBufferOffset + 2) - m2 = secondMoment1 + secondMoment2 + delta * deltaN * n1 * n2 - buffer1.setDouble(secondMomentOffset, m2) - } + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + m2 / n) + } - if (momentOrder >= 3) { - thirdMoment1 = buffer1.getDouble(thirdMomentOffset) - thirdMoment2 = buffer2.getDouble(inputAggBufferOffset + 3) - m3 = thirdMoment1 + thirdMoment2 + deltaN * deltaN * delta * n1 * n2 * - (n1 - n2) + 3.0 * deltaN * (n1 * secondMoment2 - n2 * secondMoment1) - buffer1.setDouble(thirdMomentOffset, m3) - } + override def prettyName: String = "var_pop" +} - if (momentOrder >= 4) { 
- fourthMoment1 = buffer1.getDouble(fourthMomentOffset) - fourthMoment2 = buffer2.getDouble(inputAggBufferOffset + 4) - m4 = fourthMoment1 + fourthMoment2 + deltaN * deltaN * deltaN * delta * n1 * - n2 * (n1 * n1 - n1 * n2 + n2 * n2) + deltaN * deltaN * 6.0 * - (n1 * n1 * secondMoment2 + n2 * n2 * secondMoment1) + - 4.0 * deltaN * (n1 * thirdMoment2 - n2 * thirdMoment1) - buffer1.setDouble(fourthMomentOffset, m4) - } +// Compute the sample variance of a column +case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) { + + override protected def momentOrder = 2 + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + If(n === Literal(1.0), Literal(Double.NaN), + m2 / (n - Literal(1.0)))) } - /** - * Compute aggregate statistic from sufficient moments. - * @param centralMoments Length `momentOrder + 1` array of central moments (un-normalized) - * needed to compute the aggregate stat. - */ - def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Any - - override final def eval(buffer: InternalRow): Any = { - val n = buffer.getDouble(nOffset) - val mean = buffer.getDouble(meanOffset) - val moments = Array.ofDim[Double](momentOrder + 1) - moments(0) = 1.0 - moments(1) = 0.0 - if (momentOrder >= 2) { - moments(2) = buffer.getDouble(secondMomentOffset) - } - if (momentOrder >= 3) { - moments(3) = buffer.getDouble(thirdMomentOffset) - } - if (momentOrder >= 4) { - moments(4) = buffer.getDouble(fourthMomentOffset) - } + override def prettyName: String = "var_samp" +} + +case class Skewness(child: Expression) extends CentralMomentAgg(child) { + + override def prettyName: String = "skewness" + + override protected def momentOrder = 3 - getStatistic(n, mean, moments) + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + If(m2 === Literal(0.0), Literal(Double.NaN), + Sqrt(n) * m3 / Sqrt(m2 * m2 * m2))) } } + +case class Kurtosis(child: Expression) extends CentralMomentAgg(child) { + + override protected def momentOrder = 4 + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + If(m2 === Literal(0.0), Literal(Double.NaN), + n * m4 / (m2 * m2) - Literal(3.0))) + } + + override def prettyName: String = "kurtosis" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index d25f3335ffd9..e6b8214ef25e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -29,165 +28,70 @@ import org.apache.spark.sql.types._ * Definition of Pearson correlation can be found at * http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient */ -case class Corr( - left: Expression, - right: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends ImperativeAggregate { - - def this(left: Expression, right: Expression) = - this(left, right, 
mutableAggBufferOffset = 0, inputAggBufferOffset = 0) - - override def children: Seq[Expression] = Seq(left, right) +case class Corr(x: Expression, y: Expression) extends DeclarativeAggregate { + override def children: Seq[Expression] = Seq(x, y) override def nullable: Boolean = true - override def dataType: DataType = DoubleType - override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) - override def checkInputDataTypes(): TypeCheckResult = { - if (left.dataType.isInstanceOf[DoubleType] && right.dataType.isInstanceOf[DoubleType]) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure( - s"corr requires that both arguments are double type, " + - s"not (${left.dataType}, ${right.dataType}).") - } + protected val n = AttributeReference("n", DoubleType, nullable = false)() + protected val xAvg = AttributeReference("xAvg", DoubleType, nullable = false)() + protected val yAvg = AttributeReference("yAvg", DoubleType, nullable = false)() + protected val ck = AttributeReference("ck", DoubleType, nullable = false)() + protected val xMk = AttributeReference("xMk", DoubleType, nullable = false)() + protected val yMk = AttributeReference("yMk", DoubleType, nullable = false)() + + override val aggBufferAttributes: Seq[AttributeReference] = Seq(n, xAvg, yAvg, ck, xMk, yMk) + + override val initialValues: Seq[Expression] = Array.fill(6)(Literal(0.0)) + + override val updateExpressions: Seq[Expression] = { + val newN = n + Literal(1.0) + val dx = x - xAvg + val dxN = dx / newN + val dy = y - yAvg + val dyN = dy / newN + val newXAvg = xAvg + dxN + val newYAvg = yAvg + dyN + val newCk = ck + dx * (y - newYAvg) + val newXMk = xMk + dx * (x - newXAvg) + val newYMk = yMk + dy * (y - newYAvg) + + val isNull = IsNull(x) || IsNull(y) + Seq( + If(isNull, n, newN), + If(isNull, xAvg, newXAvg), + If(isNull, yAvg, newYAvg), + If(isNull, ck, newCk), + If(isNull, xMk, newXMk), + If(isNull, yMk, newYMk) + ) } - override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) - - override def inputAggBufferAttributes: Seq[AttributeReference] = { - aggBufferAttributes.map(_.newInstance()) + override val mergeExpressions: Seq[Expression] = { + + val n1 = n.left + val n2 = n.right + val newN = n1 + n2 + val dx = xAvg.right - xAvg.left + val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN) + val dy = yAvg.right - yAvg.left + val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN) + val newXAvg = xAvg.left + dxN * n2 + val newYAvg = yAvg.left + dyN * n2 + val newCk = ck.left + ck.right + dx * dyN * n1 * n2 + val newXMk = xMk.left + xMk.right + dx * dxN * n1 * n2 + val newYMk = yMk.left + yMk.right + dy * dyN * n1 * n2 + + Seq(newN, newXAvg, newYAvg, newCk, newXMk, newYMk) } - override val aggBufferAttributes: Seq[AttributeReference] = Seq( - AttributeReference("xAvg", DoubleType)(), - AttributeReference("yAvg", DoubleType)(), - AttributeReference("Ck", DoubleType)(), - AttributeReference("MkX", DoubleType)(), - AttributeReference("MkY", DoubleType)(), - AttributeReference("count", LongType)()) - - // Local cache of mutableAggBufferOffset(s) that will be used in update and merge - private[this] val mutableAggBufferOffsetPlus1 = mutableAggBufferOffset + 1 - private[this] val mutableAggBufferOffsetPlus2 = mutableAggBufferOffset + 2 - private[this] val mutableAggBufferOffsetPlus3 = mutableAggBufferOffset + 3 - private[this] val mutableAggBufferOffsetPlus4 = mutableAggBufferOffset + 4 - private[this] val mutableAggBufferOffsetPlus5 = 
mutableAggBufferOffset + 5 - - // Local cache of inputAggBufferOffset(s) that will be used in update and merge - private[this] val inputAggBufferOffsetPlus1 = inputAggBufferOffset + 1 - private[this] val inputAggBufferOffsetPlus2 = inputAggBufferOffset + 2 - private[this] val inputAggBufferOffsetPlus3 = inputAggBufferOffset + 3 - private[this] val inputAggBufferOffsetPlus4 = inputAggBufferOffset + 4 - private[this] val inputAggBufferOffsetPlus5 = inputAggBufferOffset + 5 - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def initialize(buffer: MutableRow): Unit = { - buffer.setDouble(mutableAggBufferOffset, 0.0) - buffer.setDouble(mutableAggBufferOffsetPlus1, 0.0) - buffer.setDouble(mutableAggBufferOffsetPlus2, 0.0) - buffer.setDouble(mutableAggBufferOffsetPlus3, 0.0) - buffer.setDouble(mutableAggBufferOffsetPlus4, 0.0) - buffer.setLong(mutableAggBufferOffsetPlus5, 0L) + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + If(n === Literal(1.0), Literal(Double.NaN), + ck / Sqrt(xMk * yMk))) } - override def update(buffer: MutableRow, input: InternalRow): Unit = { - val leftEval = left.eval(input) - val rightEval = right.eval(input) - - if (leftEval != null && rightEval != null) { - val x = leftEval.asInstanceOf[Double] - val y = rightEval.asInstanceOf[Double] - - var xAvg = buffer.getDouble(mutableAggBufferOffset) - var yAvg = buffer.getDouble(mutableAggBufferOffsetPlus1) - var Ck = buffer.getDouble(mutableAggBufferOffsetPlus2) - var MkX = buffer.getDouble(mutableAggBufferOffsetPlus3) - var MkY = buffer.getDouble(mutableAggBufferOffsetPlus4) - var count = buffer.getLong(mutableAggBufferOffsetPlus5) - - val deltaX = x - xAvg - val deltaY = y - yAvg - count += 1 - xAvg += deltaX / count - yAvg += deltaY / count - Ck += deltaX * (y - yAvg) - MkX += deltaX * (x - xAvg) - MkY += deltaY * (y - yAvg) - - buffer.setDouble(mutableAggBufferOffset, xAvg) - buffer.setDouble(mutableAggBufferOffsetPlus1, yAvg) - buffer.setDouble(mutableAggBufferOffsetPlus2, Ck) - buffer.setDouble(mutableAggBufferOffsetPlus3, MkX) - buffer.setDouble(mutableAggBufferOffsetPlus4, MkY) - buffer.setLong(mutableAggBufferOffsetPlus5, count) - } - } - - // Merge counters from other partitions. Formula can be found at: - // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance - override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - val count2 = buffer2.getLong(inputAggBufferOffsetPlus5) - - // We only go to merge two buffers if there is at least one record aggregated in buffer2. - // We don't need to check count in buffer1 because if count2 is more than zero, totalCount - // is more than zero too, then we won't get a divide by zero exception. 
- if (count2 > 0) { - var xAvg = buffer1.getDouble(mutableAggBufferOffset) - var yAvg = buffer1.getDouble(mutableAggBufferOffsetPlus1) - var Ck = buffer1.getDouble(mutableAggBufferOffsetPlus2) - var MkX = buffer1.getDouble(mutableAggBufferOffsetPlus3) - var MkY = buffer1.getDouble(mutableAggBufferOffsetPlus4) - var count = buffer1.getLong(mutableAggBufferOffsetPlus5) - - val xAvg2 = buffer2.getDouble(inputAggBufferOffset) - val yAvg2 = buffer2.getDouble(inputAggBufferOffsetPlus1) - val Ck2 = buffer2.getDouble(inputAggBufferOffsetPlus2) - val MkX2 = buffer2.getDouble(inputAggBufferOffsetPlus3) - val MkY2 = buffer2.getDouble(inputAggBufferOffsetPlus4) - - val totalCount = count + count2 - val deltaX = xAvg - xAvg2 - val deltaY = yAvg - yAvg2 - Ck += Ck2 + deltaX * deltaY * count / totalCount * count2 - xAvg = (xAvg * count + xAvg2 * count2) / totalCount - yAvg = (yAvg * count + yAvg2 * count2) / totalCount - MkX += MkX2 + deltaX * deltaX * count / totalCount * count2 - MkY += MkY2 + deltaY * deltaY * count / totalCount * count2 - count = totalCount - - buffer1.setDouble(mutableAggBufferOffset, xAvg) - buffer1.setDouble(mutableAggBufferOffsetPlus1, yAvg) - buffer1.setDouble(mutableAggBufferOffsetPlus2, Ck) - buffer1.setDouble(mutableAggBufferOffsetPlus3, MkX) - buffer1.setDouble(mutableAggBufferOffsetPlus4, MkY) - buffer1.setLong(mutableAggBufferOffsetPlus5, count) - } - } - - override def eval(buffer: InternalRow): Any = { - val count = buffer.getLong(mutableAggBufferOffsetPlus5) - if (count > 0) { - val Ck = buffer.getDouble(mutableAggBufferOffsetPlus2) - val MkX = buffer.getDouble(mutableAggBufferOffsetPlus3) - val MkY = buffer.getDouble(mutableAggBufferOffsetPlus4) - val corr = Ck / math.sqrt(MkX * MkY) - if (corr.isNaN) { - null - } else { - corr - } - } else { - null - } - } + override def prettyName: String = "corr" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala index f53b01be2a0d..c175a8c4c77b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala @@ -17,182 +17,79 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** * Compute the covariance between two expressions. * When applied on empty data (i.e., count is zero), it returns NULL. 
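The correlation and covariance aggregates in this diff keep the same kind of running state; restating their expressions for the second-order co-moment only, updating with a new pair (x, y) and evaluating give

    C' = C + (x - \bar{x})\,(y - \bar{y}'),\qquad \mathrm{covar\_pop} = \frac{C}{n},\qquad \mathrm{corr} = \frac{C}{\sqrt{M_{2,x}\, M_{2,y}}}

where \bar{y}' is the updated mean of y and M_{2,x}, M_{2,y} are the running second central moments of x and y.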
- * */ -abstract class Covariance(left: Expression, right: Expression) extends ImperativeAggregate - with Serializable { - override def children: Seq[Expression] = Seq(left, right) +abstract class Covariance(x: Expression, y: Expression) extends DeclarativeAggregate { + override def children: Seq[Expression] = Seq(x, y) override def nullable: Boolean = true - override def dataType: DataType = DoubleType - override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) - override def checkInputDataTypes(): TypeCheckResult = { - if (left.dataType.isInstanceOf[DoubleType] && right.dataType.isInstanceOf[DoubleType]) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure( - s"covariance requires that both arguments are double type, " + - s"not (${left.dataType}, ${right.dataType}).") - } - } - - override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) - - override def inputAggBufferAttributes: Seq[AttributeReference] = { - aggBufferAttributes.map(_.newInstance()) - } - - override val aggBufferAttributes: Seq[AttributeReference] = Seq( - AttributeReference("xAvg", DoubleType)(), - AttributeReference("yAvg", DoubleType)(), - AttributeReference("Ck", DoubleType)(), - AttributeReference("count", LongType)()) - - // Local cache of mutableAggBufferOffset(s) that will be used in update and merge - val xAvgOffset = mutableAggBufferOffset - val yAvgOffset = mutableAggBufferOffset + 1 - val CkOffset = mutableAggBufferOffset + 2 - val countOffset = mutableAggBufferOffset + 3 - - // Local cache of inputAggBufferOffset(s) that will be used in update and merge - val inputXAvgOffset = inputAggBufferOffset - val inputYAvgOffset = inputAggBufferOffset + 1 - val inputCkOffset = inputAggBufferOffset + 2 - val inputCountOffset = inputAggBufferOffset + 3 - - override def initialize(buffer: MutableRow): Unit = { - buffer.setDouble(xAvgOffset, 0.0) - buffer.setDouble(yAvgOffset, 0.0) - buffer.setDouble(CkOffset, 0.0) - buffer.setLong(countOffset, 0L) - } - - override def update(buffer: MutableRow, input: InternalRow): Unit = { - val leftEval = left.eval(input) - val rightEval = right.eval(input) - - if (leftEval != null && rightEval != null) { - val x = leftEval.asInstanceOf[Double] - val y = rightEval.asInstanceOf[Double] - - var xAvg = buffer.getDouble(xAvgOffset) - var yAvg = buffer.getDouble(yAvgOffset) - var Ck = buffer.getDouble(CkOffset) - var count = buffer.getLong(countOffset) - - val deltaX = x - xAvg - val deltaY = y - yAvg - count += 1 - xAvg += deltaX / count - yAvg += deltaY / count - Ck += deltaX * (y - yAvg) - - buffer.setDouble(xAvgOffset, xAvg) - buffer.setDouble(yAvgOffset, yAvg) - buffer.setDouble(CkOffset, Ck) - buffer.setLong(countOffset, count) - } + protected val n = AttributeReference("n", DoubleType, nullable = false)() + protected val xAvg = AttributeReference("xAvg", DoubleType, nullable = false)() + protected val yAvg = AttributeReference("yAvg", DoubleType, nullable = false)() + protected val ck = AttributeReference("ck", DoubleType, nullable = false)() + + override val aggBufferAttributes: Seq[AttributeReference] = Seq(n, xAvg, yAvg, ck) + + override val initialValues: Seq[Expression] = Array.fill(4)(Literal(0.0)) + + override lazy val updateExpressions: Seq[Expression] = { + val newN = n + Literal(1.0) + val dx = x - xAvg + val dy = y - yAvg + val dyN = dy / newN + val newXAvg = xAvg + dx / newN + val newYAvg = yAvg + dyN + val newCk = ck + dx * (y - newYAvg) + + val isNull = IsNull(x) || IsNull(y) + Seq( + 
If(isNull, n, newN), + If(isNull, xAvg, newXAvg), + If(isNull, yAvg, newYAvg), + If(isNull, ck, newCk) + ) } - // Merge counters from other partitions. Formula can be found at: - // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance - override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - val count2 = buffer2.getLong(inputCountOffset) - - // We only go to merge two buffers if there is at least one record aggregated in buffer2. - // We don't need to check count in buffer1 because if count2 is more than zero, totalCount - // is more than zero too, then we won't get a divide by zero exception. - if (count2 > 0) { - var xAvg = buffer1.getDouble(xAvgOffset) - var yAvg = buffer1.getDouble(yAvgOffset) - var Ck = buffer1.getDouble(CkOffset) - var count = buffer1.getLong(countOffset) + override val mergeExpressions: Seq[Expression] = { - val xAvg2 = buffer2.getDouble(inputXAvgOffset) - val yAvg2 = buffer2.getDouble(inputYAvgOffset) - val Ck2 = buffer2.getDouble(inputCkOffset) + val n1 = n.left + val n2 = n.right + val newN = n1 + n2 + val dx = xAvg.right - xAvg.left + val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN) + val dy = yAvg.right - yAvg.left + val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN) + val newXAvg = xAvg.left + dxN * n2 + val newYAvg = yAvg.left + dyN * n2 + val newCk = ck.left + ck.right + dx * dyN * n1 * n2 - val totalCount = count + count2 - val deltaX = xAvg - xAvg2 - val deltaY = yAvg - yAvg2 - Ck += Ck2 + deltaX * deltaY * count / totalCount * count2 - xAvg = (xAvg * count + xAvg2 * count2) / totalCount - yAvg = (yAvg * count + yAvg2 * count2) / totalCount - count = totalCount - - buffer1.setDouble(xAvgOffset, xAvg) - buffer1.setDouble(yAvgOffset, yAvg) - buffer1.setDouble(CkOffset, Ck) - buffer1.setLong(countOffset, count) - } + Seq(newN, newXAvg, newYAvg, newCk) } } -case class CovSample( - left: Expression, - right: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends Covariance(left, right) { - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def eval(buffer: InternalRow): Any = { - val count = buffer.getLong(countOffset) - if (count > 1) { - val Ck = buffer.getDouble(CkOffset) - val cov = Ck / (count - 1) - if (cov.isNaN) { - null - } else { - cov - } - } else { - null - } +case class CovPopulation(left: Expression, right: Expression) extends Covariance(left, right) { + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + ck / n) } + override def prettyName: String = "covar_pop" } -case class CovPopulation( - left: Expression, - right: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends Covariance(left, right) { - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - override def eval(buffer: InternalRow): Any = { - val count = buffer.getLong(countOffset) - if (count > 0) { - val Ck = buffer.getDouble(CkOffset) - if (Ck.isNaN) { - null - } else { - Ck / count - 
} - } else { - null - } +case class CovSample(left: Expression, right: Expression) extends Covariance(left, right) { + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + If(n === Literal(1.0), Literal(Double.NaN), + ck / (n - Literal(1.0)))) } + override def prettyName: String = "covar_samp" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala deleted file mode 100644 index c2bf2cb94116..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -import org.apache.spark.sql.catalyst.expressions._ - -case class Kurtosis(child: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends CentralMomentAgg(child) { - - def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def prettyName: String = "kurtosis" - - override protected val momentOrder = 4 - - // NOTE: this is the formula for excess kurtosis, which is default for R and SciPy - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { - require(moments.length == momentOrder + 1, - s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") - val m2 = moments(2) - val m4 = moments(4) - - if (n == 0.0) { - null - } else if (m2 == 0.0) { - Double.NaN - } else { - n * m4 / (m2 * m2) - 3.0 - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala deleted file mode 100644 index 9411bcea2539..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -import org.apache.spark.sql.catalyst.expressions._ - -case class Skewness(child: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends CentralMomentAgg(child) { - - def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def prettyName: String = "skewness" - - override protected val momentOrder = 3 - - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { - require(moments.length == momentOrder + 1, - s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") - val m2 = moments(2) - val m3 = moments(3) - - if (n == 0.0) { - null - } else if (m2 == 0.0) { - Double.NaN - } else { - math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2) - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala deleted file mode 100644 index eec79a9033e3..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -import org.apache.spark.sql.catalyst.expressions._ - -case class StddevSamp(child: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends CentralMomentAgg(child) { - - def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def prettyName: String = "stddev_samp" - - override protected val momentOrder = 2 - - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { - require(moments.length == momentOrder + 1, - s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - - if (n == 0.0) { - null - } else if (n == 1.0) { - Double.NaN - } else { - math.sqrt(moments(2) / (n - 1.0)) - } - } -} - -case class StddevPop( - child: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends CentralMomentAgg(child) { - - def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def prettyName: String = "stddev_pop" - - override protected val momentOrder = 2 - - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { - require(moments.length == momentOrder + 1, - s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - - if (n == 0.0) { - null - } else { - math.sqrt(moments(2) / n) - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala deleted file mode 100644 index cf3a74030539..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -import org.apache.spark.sql.catalyst.expressions._ - -case class VarianceSamp(child: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends CentralMomentAgg(child) { - - def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def prettyName: String = "var_samp" - - override protected val momentOrder = 2 - - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { - require(moments.length == momentOrder + 1, - s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - - if (n == 0.0) { - null - } else if (n == 1.0) { - Double.NaN - } else { - moments(2) / (n - 1.0) - } - } -} - -case class VariancePop( - child: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends CentralMomentAgg(child) { - - def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def prettyName: String = "var_pop" - - override protected val momentOrder = 2 - - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { - require(moments.length == momentOrder + 1, - s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - - if (n == 0.0) { - null - } else { - moments(2) / n - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 36e1fa1176d2..f4ccadd9c563 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -424,3 +424,21 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression } } } + +/** + * Print the result of an expression to stderr (used for debugging codegen). 
+ */ +case class PrintToStderr(child: Expression) extends UnaryExpression { + + override def dataType: DataType = child.dataType + + protected override def nullSafeEval(input: Any): Any = input + + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + nullSafeCodeGen(ctx, ev, c => + s""" + | System.err.println("Result of ${child.simpleString} is " + $c); + | ${ev.value} = $c; + """.stripMargin) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 26a7340f1ae1..84154a47de39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -198,7 +198,8 @@ case class Window( functions, ordinal, child.output, - (expressions, schema) => newMutableProjection(expressions, schema)) + (expressions, schema) => + newMutableProjection(expressions, schema, subexpressionEliminationEnabled)) // Create the factory val factory = key match { @@ -210,7 +211,8 @@ case class Window( ordinal, functions, child.output, - (expressions, schema) => newMutableProjection(expressions, schema), + (expressions, schema) => + newMutableProjection(expressions, schema, subexpressionEliminationEnabled), offset) // Growing Frame. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 57db7262fdaf..a8a81d6d6574 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -240,7 +240,6 @@ case class TungstenAggregate( | ${bufVars(i).value} = ${ev.value}; """.stripMargin } - s""" | // do aggregate | ${aggVals.map(_.code).mkString("\n")} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 2f09c8a114bc..1ccf0e3d0656 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -59,6 +59,55 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { benchmark.run() } + def testStatFunctions(values: Int): Unit = { + + val benchmark = new Benchmark("stat functions", values) + + benchmark.addCase("stddev w/o codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "false") + sqlContext.range(values).groupBy().agg("id" -> "stddev").collect() + } + + benchmark.addCase("stddev w codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "true") + sqlContext.range(values).groupBy().agg("id" -> "stddev").collect() + } + + benchmark.addCase("kurtosis w/o codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "false") + sqlContext.range(values).groupBy().agg("id" -> "kurtosis").collect() + } + + benchmark.addCase("kurtosis w codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "true") + sqlContext.range(values).groupBy().agg("id" -> "kurtosis").collect() + } + + + /** + Using ImperativeAggregate (as implemented in Spark 1.6): + + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + stddev: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + 
stddev w/o codegen 2019.04 10.39 1.00 X + stddev w codegen 2097.29 10.00 0.96 X + kurtosis w/o codegen 2108.99 9.94 0.96 X + kurtosis w codegen 2090.69 10.03 0.97 X + + Using DeclarativeAggregate: + + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + stddev: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + stddev w/o codegen 989.22 21.20 1.00 X + stddev w codegen 352.35 59.52 2.81 X + kurtosis w/o codegen 3636.91 5.77 0.27 X + kurtosis w codegen 369.25 56.79 2.68 X + */ + benchmark.run() + } + def testAggregateWithKey(values: Int): Unit = { val benchmark = new Benchmark("Aggregate with keys", values) @@ -147,8 +196,10 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { benchmark.run() } - test("benchmark") { - // testWholeStage(1024 * 1024 * 200) + // These benchmark are skipped in normal build + ignore("benchmark") { + // testWholeStage(200 << 20) + // testStddev(20 << 20) // testAggregateWithKey(20 << 20) // testBytesToBytesMap(1024 * 1024 * 50) } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 554d47d651ae..61b73fa55714 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -325,6 +325,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "drop_partitions_ignore_protection", "protectmode", + // Hive returns null rather than NaN when n = 1 + "udaf_covar_samp", + // Spark parser treats numerical literals differently: it creates decimals instead of doubles. 
"udf_abs", "udf_format_number", @@ -881,7 +884,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "type_widening", "udaf_collect_set", "udaf_covar_pop", - "udaf_covar_samp", "udaf_histogram_numeric", "udf2", "udf5", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 7a9ed1eaf3db..caf1db9ad085 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -798,7 +798,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te """ |SELECT corr(b, c) FROM covar_tab WHERE a = 3 """.stripMargin), - Row(null) :: Nil) + Row(Double.NaN) :: Nil) checkAnswer( sqlContext.sql( @@ -807,10 +807,10 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te """.stripMargin), Row(1, null) :: Row(2, null) :: - Row(3, null) :: - Row(4, null) :: - Row(5, null) :: - Row(6, null) :: Nil) + Row(3, Double.NaN) :: + Row(4, Double.NaN) :: + Row(5, Double.NaN) :: + Row(6, Double.NaN) :: Nil) val corr7 = sqlContext.sql("SELECT corr(b, c) FROM covar_tab").collect()(0).getDouble(0) assert(math.abs(corr7 - 0.6633880657639323) < 1e-12) @@ -841,11 +841,8 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te // one row test val df3 = Seq.tabulate(1)(x => (1 * x, x * x * x - 2)).toDF("a", "b") - val cov_samp3 = df3.groupBy().agg(covar_samp("a", "b")).collect()(0).get(0) - assert(cov_samp3 == null) - - val cov_pop3 = df3.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0) - assert(cov_pop3 == 0.0) + checkAnswer(df3.groupBy().agg(covar_samp("a", "b")), Row(Double.NaN)) + checkAnswer(df3.groupBy().agg(covar_pop("a", "b")), Row(0.0)) } test("no aggregation function (SPARK-11486)") { From d0df2ca40953ba581dce199798a168af01283cdc Mon Sep 17 00:00:00 2001 From: Gabriele Nizzoli Date: Tue, 2 Feb 2016 13:20:01 -0800 Subject: [PATCH 102/131] [SPARK-13121][STREAMING] java mapWithState mishandles scala Option Already merged into 1.6 branch, this PR is to commit to master the same change Author: Gabriele Nizzoli Closes #11028 from gabrielenizzoli/patch-1. 
--- .../src/main/scala/org/apache/spark/streaming/StateSpec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala index 66f646d7dc13..e6724feaee10 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala @@ -221,7 +221,7 @@ object StateSpec { mappingFunction: JFunction3[KeyType, Optional[ValueType], State[StateType], MappedType]): StateSpec[KeyType, ValueType, StateType, MappedType] = { val wrappedFunc = (k: KeyType, v: Option[ValueType], s: State[StateType]) => { - mappingFunction.call(k, Optional.ofNullable(v.get), s) + mappingFunction.call(k, JavaUtils.optionToOptional(v), s) } StateSpec.function(wrappedFunc) } From b377b03531d21b1d02a8f58b3791348962e1f31b Mon Sep 17 00:00:00 2001 From: "Kevin (Sangwoo) Kim" Date: Tue, 2 Feb 2016 13:24:09 -0800 Subject: [PATCH 103/131] [DOCS] Update StructType.scala The example will throw error like :20: error: not found: value StructType Need to add this line: import org.apache.spark.sql.types._ Author: Kevin (Sangwoo) Kim Closes #10141 from swkimme/patch-1. --- .../src/main/scala/org/apache/spark/sql/types/StructType.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index c9e7e7fe633b..e797d83cb05b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.catalyst.util.{DataTypeParser, LegacyTypeStringParse * Example: * {{{ * import org.apache.spark.sql._ + * import org.apache.spark.sql.types._ * * val struct = * StructType( From 6de6a97728408ee2619006decf2267cc43eeea0d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 2 Feb 2016 16:24:31 -0800 Subject: [PATCH 104/131] [SPARK-13150] [SQL] disable two flaky tests Author: Davies Liu Closes #11037 from davies/disable_flaky. 
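The mechanism used below is ScalaTest's ignore, which keeps a test body compiling but skips it at run time; a minimal standalone sketch (the suite and test names here are made up for illustration):

    import org.scalatest.FunSuite

    class ExampleSuite extends FunSuite {
      // TODO: re-enable once the source of flakiness is fixed.
      ignore("flaky behaviour under investigation") {
        assert(1 + 1 == 2)   // still type-checked, but not executed
      }
    }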
--- .../sql/hive/thriftserver/HiveThriftServer2Suites.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index ba3b26e1b7d4..9860e40fe854 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -488,7 +488,8 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } } - test("SPARK-11595 ADD JAR with input path having URL scheme") { + // TODO: enable this + ignore("SPARK-11595 ADD JAR with input path having URL scheme") { withJdbcStatement { statement => val jarPath = "../hive/src/test/resources/TestUDTF.jar" val jarURL = s"file://${System.getProperty("user.dir")}/$jarPath" @@ -546,7 +547,8 @@ class SingleSessionSuite extends HiveThriftJdbcTest { override protected def extraConf: Seq[String] = "--conf spark.sql.hive.thriftServer.singleSession=true" :: Nil - test("test single session") { + // TODO: enable this + ignore("test single session") { withMultipleConnectionJdbcStatement( { statement => val jarPath = "../hive/src/test/resources/TestUDTF.jar" From 672032d0ab1e43bc5a25cecdb1b96dfd35c39778 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 3 Feb 2016 08:26:35 +0800 Subject: [PATCH 105/131] [SPARK-13020][SQL][TEST] fix random generator for map type when we generate map, we first randomly pick a length, then create a seq of key value pair with the expected length, and finally call `toMap`. However, `toMap` will remove all duplicated keys, which makes the actual map size much less than we expected. This PR fixes this problem by put keys in a set first, to guarantee we have enough keys to build a map with expected length. Author: Wenchen Fan Closes #10930 from cloud-fan/random-generator. --- .../apache/spark/sql/RandomDataGenerator.scala | 18 ++++++++++++++---- .../spark/sql/RandomDataGeneratorSuite.scala | 11 +++++++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 55efea80d1a4..7c173cbceefe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -47,9 +47,9 @@ object RandomDataGenerator { */ private val PROBABILITY_OF_NULL: Float = 0.1f - private val MAX_STR_LEN: Int = 1024 - private val MAX_ARR_SIZE: Int = 128 - private val MAX_MAP_SIZE: Int = 128 + final val MAX_STR_LEN: Int = 1024 + final val MAX_ARR_SIZE: Int = 128 + final val MAX_MAP_SIZE: Int = 128 /** * Helper function for constructing a biased random number generator which returns "interesting" @@ -208,7 +208,17 @@ object RandomDataGenerator { forType(valueType, nullable = valueContainsNull, rand) ) yield { () => { - Seq.fill(rand.nextInt(MAX_MAP_SIZE))((keyGenerator(), valueGenerator())).toMap + val length = rand.nextInt(MAX_MAP_SIZE) + val keys = scala.collection.mutable.HashSet(Seq.fill(length)(keyGenerator()): _*) + // In case the number of different keys is not enough, set a max iteration to avoid + // infinite loop. 
+ var count = 0 + while (keys.size < length && count < MAX_MAP_SIZE) { + keys += keyGenerator() + count += 1 + } + val values = Seq.fill(keys.size)(valueGenerator()) + keys.zip(values).toMap } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala index b8ccdf7516d8..9fba7924e954 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala @@ -95,4 +95,15 @@ class RandomDataGeneratorSuite extends SparkFunSuite { } } + test("check size of generated map") { + val mapType = MapType(IntegerType, IntegerType) + for (seed <- 1 to 1000) { + val generator = RandomDataGenerator.forType( + mapType, nullable = false, rand = new Random(seed)).get + val maps = Seq.fill(100)(generator().asInstanceOf[Map[Int, Int]]) + val expectedTotalElements = 100 / 2 * RandomDataGenerator.MAX_MAP_SIZE + val deviation = math.abs(maps.map(_.size).sum - expectedTotalElements) + assert(deviation.toDouble / expectedTotalElements < 2e-1) + } + } } From 21112e8a14c042ccef4312079672108a1082a95e Mon Sep 17 00:00:00 2001 From: Nong Li Date: Tue, 2 Feb 2016 16:33:21 -0800 Subject: [PATCH 106/131] [SPARK-12992] [SQL] Update parquet reader to support more types when decoding to ColumnarBatch. This patch implements support for more types when doing the vectorized decode. There are a few more types remaining but they should be very straightforward after this. This code has a few copy and paste pieces but they are difficult to eliminate due to performance considerations. Specifically, this patch adds support for: - String, Long, Byte types - Dictionary encoding for those types. Author: Nong Li Closes #10908 from nongli/spark-12992. 
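For the dictionary-encoded case the reader now works in two passes: the RLE/bit-packed dictionary ids for the whole batch are first decoded into a staging integer vector, and each id is then resolved against the page dictionary into the output column. A simplified sketch of that second pass, with plain arrays standing in for the real ColumnVector and Parquet dictionary classes (only the INT32 case is shown; the long and binary cases follow the same shape via decodeToLong and decodeToBinary):

    // Illustration only: simplified stand-ins, not the Parquet or Spark classes.
    class IntDictionary(values: Array[Int]) {
      def decodeToInt(id: Int): Int = values(id)
    }

    // Pass 1 (not shown): fill `ids` for the batch from the RLE/bit-packed data and
    // record NULL slots in `isNull` from the definition levels.
    // Pass 2: map every non-null id through the dictionary into the output column.
    def decodeDictionaryIds(
        ids: Array[Int],
        isNull: Array[Boolean],
        dict: IntDictionary): Array[java.lang.Integer] = {
      Array.tabulate[java.lang.Integer](ids.length) { i =>
        if (isNull(i)) null else Int.box(dict.decodeToInt(ids(i)))
      }
    }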
--- .../parquet/UnsafeRowParquetRecordReader.java | 146 ++++++++++++++-- .../parquet/VectorizedPlainValuesReader.java | 45 ++++- .../parquet/VectorizedRleValuesReader.java | 160 +++++++++++++++++- .../parquet/VectorizedValuesReader.java | 5 + .../execution/vectorized/ColumnVector.java | 7 +- .../parquet/ParquetEncodingSuite.scala | 82 +++++++++ 6 files changed, 424 insertions(+), 21 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index 17adfec32192..b5dddb9f11b2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -21,6 +21,7 @@ import java.nio.ByteBuffer; import java.util.List; +import org.apache.commons.lang.NotImplementedException; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.parquet.Preconditions; @@ -41,6 +42,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; import org.apache.spark.sql.execution.vectorized.ColumnVector; import org.apache.spark.sql.execution.vectorized.ColumnarBatch; +import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.UTF8String; @@ -207,13 +209,7 @@ public boolean nextBatch() throws IOException { int num = (int)Math.min((long) columnarBatch.capacity(), totalRowCount - rowsReturned); for (int i = 0; i < columnReaders.length; ++i) { - switch (columnReaders[i].descriptor.getType()) { - case INT32: - columnReaders[i].readIntBatch(num, columnarBatch.column(i)); - break; - default: - throw new IOException("Unsupported type: " + columnReaders[i].descriptor.getType()); - } + columnReaders[i].readBatch(num, columnarBatch.column(i)); } rowsReturned += num; columnarBatch.setNumRows(num); @@ -237,7 +233,8 @@ private void initializeInternal() throws IOException { // TODO: Be extremely cautious in what is supported. Expand this. if (originalTypes[i] != null && originalTypes[i] != OriginalType.DECIMAL && - originalTypes[i] != OriginalType.UTF8 && originalTypes[i] != OriginalType.DATE) { + originalTypes[i] != OriginalType.UTF8 && originalTypes[i] != OriginalType.DATE && + originalTypes[i] != OriginalType.INT_8 && originalTypes[i] != OriginalType.INT_16) { throw new IOException("Unsupported type: " + t); } if (originalTypes[i] == OriginalType.DECIMAL && @@ -464,6 +461,11 @@ private final class ColumnReader { */ private boolean useDictionary; + /** + * If useDictionary is true, the staging vector used to decode the ids. + */ + private ColumnVector dictionaryIds; + /** * Maximum definition level for this column. */ @@ -587,9 +589,8 @@ private boolean next() throws IOException { /** * Reads `total` values from this columnReader into column. - * TODO: implement the other encodings. */ - private void readIntBatch(int total, ColumnVector column) throws IOException { + private void readBatch(int total, ColumnVector column) throws IOException { int rowId = 0; while (total > 0) { // Compute the number of values we want to read in this page. 
@@ -599,21 +600,134 @@ private void readIntBatch(int total, ColumnVector column) throws IOException { leftInPage = (int)(endOfPageValueCount - valuesRead); } int num = Math.min(total, leftInPage); - defColumn.readIntegers( - num, column, rowId, maxDefLevel, (VectorizedValuesReader)dataColumn, 0); - - // Remap the values if it is dictionary encoded. if (useDictionary) { - for (int i = rowId; i < rowId + num; ++i) { - column.putInt(i, dictionary.decodeToInt(column.getInt(i))); + // Data is dictionary encoded. We will vector decode the ids and then resolve the values. + if (dictionaryIds == null) { + dictionaryIds = ColumnVector.allocate(total, DataTypes.IntegerType, MemoryMode.ON_HEAP); + } else { + dictionaryIds.reset(); + dictionaryIds.reserve(total); + } + // Read and decode dictionary ids. + readIntBatch(rowId, num, dictionaryIds); + decodeDictionaryIds(rowId, num, column); + } else { + switch (descriptor.getType()) { + case INT32: + readIntBatch(rowId, num, column); + break; + case INT64: + readLongBatch(rowId, num, column); + break; + case BINARY: + readBinaryBatch(rowId, num, column); + break; + default: + throw new IOException("Unsupported type: " + descriptor.getType()); } } + valuesRead += num; rowId += num; total -= num; } } + /** + * Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`. + */ + private void decodeDictionaryIds(int rowId, int num, ColumnVector column) { + switch (descriptor.getType()) { + case INT32: + if (column.dataType() == DataTypes.IntegerType) { + for (int i = rowId; i < rowId + num; ++i) { + column.putInt(i, dictionary.decodeToInt(dictionaryIds.getInt(i))); + } + } else if (column.dataType() == DataTypes.ByteType) { + for (int i = rowId; i < rowId + num; ++i) { + column.putByte(i, (byte)dictionary.decodeToInt(dictionaryIds.getInt(i))); + } + } else { + throw new NotImplementedException("Unimplemented type: " + column.dataType()); + } + break; + + case INT64: + for (int i = rowId; i < rowId + num; ++i) { + column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i))); + } + break; + + case BINARY: + // TODO: this is incredibly inefficient as it blows up the dictionary right here. We + // need to do this better. We should probably add the dictionary data to the ColumnVector + // and reuse it across batches. This should mean adding a ByteArray would just update + // the length and offset. + for (int i = rowId; i < rowId + num; ++i) { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); + column.putByteArray(i, v.getBytes()); + } + break; + + default: + throw new NotImplementedException("Unsupported type: " + descriptor.getType()); + } + + if (dictionaryIds.numNulls() > 0) { + // Copy the NULLs over. + // TODO: we can improve this by decoding the NULLs directly into column. This would + // mean we decode the int ids into `dictionaryIds` and the NULLs into `column` and then + // just do the ID remapping as above. + for (int i = 0; i < num; ++i) { + if (dictionaryIds.getIsNull(rowId + i)) { + column.putNull(rowId + i); + } + } + } + } + + /** + * For all the read*Batch functions, reads `num` values from this columnReader into column. It + * is guaranteed that num is smaller than the number of values left in the current page. + */ + + private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException { + // This is where we implement support for the valid type conversions. 
+ // TODO: implement remaining type conversions + if (column.dataType() == DataTypes.IntegerType) { + defColumn.readIntegers( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn, 0); + } else if (column.dataType() == DataTypes.ByteType) { + defColumn.readBytes( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } else { + throw new NotImplementedException("Unimplemented type: " + column.dataType()); + } + } + + private void readLongBatch(int rowId, int num, ColumnVector column) throws IOException { + // This is where we implement support for the valid type conversions. + // TODO: implement remaining type conversions + if (column.dataType() == DataTypes.LongType) { + defColumn.readLongs( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } else { + throw new NotImplementedException("Unimplemented type: " + column.dataType()); + } + } + + private void readBinaryBatch(int rowId, int num, ColumnVector column) throws IOException { + // This is where we implement support for the valid type conversions. + // TODO: implement remaining type conversions + if (column.isArray()) { + defColumn.readBinarys( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } else { + throw new NotImplementedException("Unimplemented type: " + column.dataType()); + } + } + + private void readPage() throws IOException { DataPage page = pageReader.readPage(); // TODO: Why is this a visitor? diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index dac0c52ebd2c..cec2418e4603 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -18,10 +18,13 @@ import java.io.IOException; +import org.apache.spark.sql.Column; import org.apache.spark.sql.execution.vectorized.ColumnVector; import org.apache.spark.unsafe.Platform; +import org.apache.commons.lang.NotImplementedException; import org.apache.parquet.column.values.ValuesReader; +import org.apache.parquet.io.api.Binary; /** * An implementation of the Parquet PLAIN decoder that supports the vectorized interface. @@ -52,15 +55,53 @@ public void skip(int n) { } @Override - public void readIntegers(int total, ColumnVector c, int rowId) { + public final void readIntegers(int total, ColumnVector c, int rowId) { c.putIntsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); offset += 4 * total; } @Override - public int readInteger() { + public final void readLongs(int total, ColumnVector c, int rowId) { + c.putLongsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); + offset += 8 * total; + } + + @Override + public final void readBytes(int total, ColumnVector c, int rowId) { + for (int i = 0; i < total; i++) { + // Bytes are stored as a 4-byte little endian int. Just read the first byte. + // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride. 
+ c.putInt(rowId + i, buffer[offset]); + offset += 4; + } + } + + @Override + public final int readInteger() { int v = Platform.getInt(buffer, offset); offset += 4; return v; } + + @Override + public final long readLong() { + long v = Platform.getLong(buffer, offset); + offset += 8; + return v; + } + + @Override + public final byte readByte() { + return (byte)readInteger(); + } + + @Override + public final void readBinary(int total, ColumnVector v, int rowId) { + for (int i = 0; i < total; i++) { + int len = readInteger(); + int start = offset; + offset += len; + v.putByteArray(rowId + i, buffer, start - Platform.BYTE_ARRAY_OFFSET, len); + } + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index 493ec9deed49..9bfd74db3876 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -17,12 +17,16 @@ package org.apache.spark.sql.execution.datasources.parquet; +import org.apache.commons.lang.NotImplementedException; import org.apache.parquet.Preconditions; import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.values.ValuesReader; import org.apache.parquet.column.values.bitpacking.BytePacker; import org.apache.parquet.column.values.bitpacking.Packer; import org.apache.parquet.io.ParquetDecodingException; +import org.apache.parquet.io.api.Binary; + +import org.apache.spark.sql.Column; import org.apache.spark.sql.execution.vectorized.ColumnVector; /** @@ -35,7 +39,8 @@ * - Definition/Repetition levels * - Dictionary ids. */ -public final class VectorizedRleValuesReader extends ValuesReader { +public final class VectorizedRleValuesReader extends ValuesReader + implements VectorizedValuesReader { // Current decoding mode. The encoded data contains groups of either run length encoded data // (RLE) or bit packed data. Each group contains a header that indicates which group it is and // the number of values in the group. @@ -121,6 +126,7 @@ public int readValueDictionaryId() { return readInteger(); } + @Override public int readInteger() { if (this.currentCount == 0) { this.readNextGroup(); } @@ -138,7 +144,9 @@ public int readInteger() { /** * Reads `total` ints into `c` filling them in starting at `c[rowId]`. This reader * reads the definition levels and then will read from `data` for the non-null values. - * If the value is null, c will be populated with `nullValue`. + * If the value is null, c will be populated with `nullValue`. Note that `nullValue` is only + * necessary for readIntegers because we also use it to decode dictionaryIds and want to make + * sure it always has a value in range. * * This is a batched version of this logic: * if (this.readInt() == level) { @@ -180,6 +188,154 @@ public void readIntegers(int total, ColumnVector c, int rowId, int level, } } + // TODO: can this code duplication be removed without a perf penalty? 
+ public void readBytes(int total, ColumnVector c, + int rowId, int level, VectorizedValuesReader data) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + data.readBytes(n, c, rowId); + c.putNotNulls(rowId, n); + } else { + c.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + c.putByte(rowId + i, data.readByte()); + c.putNotNull(rowId + i); + } else { + c.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + public void readLongs(int total, ColumnVector c, int rowId, int level, + VectorizedValuesReader data) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + data.readLongs(n, c, rowId); + c.putNotNulls(rowId, n); + } else { + c.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + c.putLong(rowId + i, data.readLong()); + c.putNotNull(rowId + i); + } else { + c.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + public void readBinarys(int total, ColumnVector c, int rowId, int level, + VectorizedValuesReader data) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + c.putNotNulls(rowId, n); + data.readBinary(n, c, rowId); + } else { + c.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + c.putNotNull(rowId + i); + data.readBinary(1, c, rowId); + } else { + c.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + + // The RLE reader implements the vectorized decoding interface when used to decode dictionary + // IDs. This is different than the above APIs that decodes definitions levels along with values. + // Since this is only used to decode dictionary IDs, only decoding integers is supported. + @Override + public void readIntegers(int total, ColumnVector c, int rowId) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + c.putInts(rowId, n, currentValue); + break; + case PACKED: + c.putInts(rowId, n, currentBuffer, currentBufferIdx); + currentBufferIdx += n; + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + @Override + public byte readByte() { + throw new UnsupportedOperationException("only readInts is valid."); + } + + @Override + public void readBytes(int total, ColumnVector c, int rowId) { + throw new UnsupportedOperationException("only readInts is valid."); + } + + @Override + public void readLongs(int total, ColumnVector c, int rowId) { + throw new UnsupportedOperationException("only readInts is valid."); + } + + @Override + public void readBinary(int total, ColumnVector c, int rowId) { + throw new UnsupportedOperationException("only readInts is valid."); + } + + @Override + public void skip(int n) { + throw new UnsupportedOperationException("only readInts is valid."); + } + + /** * Reads the next varint encoded int. 
*/ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java index 49a9ed83d590..b6ec7311c564 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java @@ -24,12 +24,17 @@ * TODO: merge this into parquet-mr. */ public interface VectorizedValuesReader { + byte readByte(); int readInteger(); + long readLong(); /* * Reads `total` values into `c` start at `c[rowId]` */ + void readBytes(int total, ColumnVector c, int rowId); void readIntegers(int total, ColumnVector c, int rowId); + void readLongs(int total, ColumnVector c, int rowId); + void readBinary(int total, ColumnVector c, int rowId); // TODO: add all the other parquet types. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index a5bc506a65ac..0514252a8e53 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -763,7 +763,12 @@ public final int appendStruct(boolean isNull) { /** * Returns the elements appended. */ - public int getElementsAppended() { return elementsAppended; } + public final int getElementsAppended() { return elementsAppended; } + + /** + * Returns true if this column is an array. + */ + public final boolean isArray() { return resultArray != null; } /** * Maximum number of rows that can be stored in this column. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala new file mode 100644 index 000000000000..cef6b79a094d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils +import org.apache.spark.sql.test.SharedSQLContext + +// TODO: this needs a lot more testing but it's currently not easy to test with the parquet +// writer abstractions. Revisit. 
+class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContext { + import testImplicits._ + + val ROW = ((1).toByte, 2, 3L, "abc") + val NULL_ROW = ( + null.asInstanceOf[java.lang.Byte], + null.asInstanceOf[Integer], + null.asInstanceOf[java.lang.Long], + null.asInstanceOf[String]) + + test("All Types Dictionary") { + (1 :: 1000 :: Nil).foreach { n => { + withTempPath { dir => + List.fill(n)(ROW).toDF.repartition(1).write.parquet(dir.getCanonicalPath) + val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head + + val reader = new UnsafeRowParquetRecordReader + reader.initialize(file.asInstanceOf[String], null) + val batch = reader.resultBatch() + assert(reader.nextBatch()) + assert(batch.numRows() == n) + var i = 0 + while (i < n) { + assert(batch.column(0).getByte(i) == 1) + assert(batch.column(1).getInt(i) == 2) + assert(batch.column(2).getLong(i) == 3) + assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(i)) == "abc") + i += 1 + } + reader.close() + } + }} + } + + test("All Types Null") { + (1 :: 100 :: Nil).foreach { n => { + withTempPath { dir => + val data = List.fill(n)(NULL_ROW).toDF + data.repartition(1).write.parquet(dir.getCanonicalPath) + val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head + + val reader = new UnsafeRowParquetRecordReader + reader.initialize(file.asInstanceOf[String], null) + val batch = reader.resultBatch() + assert(reader.nextBatch()) + assert(batch.numRows() == n) + var i = 0 + while (i < n) { + assert(batch.column(0).getIsNull(i)) + assert(batch.column(1).getIsNull(i)) + assert(batch.column(2).getIsNull(i)) + assert(batch.column(3).getIsNull(i)) + i += 1 + } + reader.close() + }} + } + } +} From ff71261b651a7b289ea2312abd6075da8b838ed9 Mon Sep 17 00:00:00 2001 From: Adam Budde Date: Tue, 2 Feb 2016 19:35:33 -0800 Subject: [PATCH 107/131] [SPARK-13122] Fix race condition in MemoryStore.unrollSafely() https://issues.apache.org/jira/browse/SPARK-13122 A race condition can occur in MemoryStore's unrollSafely() method if two threads that return the same value for currentTaskAttemptId() execute this method concurrently. This change makes the operation of reading the initial amount of unroll memory used, performing the unroll, and updating the associated memory maps atomic in order to avoid this race condition. Initial proposed fix wraps all of unrollSafely() in a memoryManager.synchronized { } block. A cleaner approach might be introduce a mechanism that synchronizes based on task attempt ID. An alternative option might be to track unroll/pending unroll memory based on block ID rather than task attempt ID. Author: Adam Budde Closes #11012 from budde/master. 
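The fix boils down to replacing a read-then-diff against shared per-task state with a counter that is local to the call. A simplified standalone sketch of the two shapes (the map and method names below are stand-ins for the MemoryStore bookkeeping, not its real fields):

    import scala.collection.mutable

    val unrollMemoryMap = mutable.HashMap[Long, Long]().withDefaultValue(0L)

    // Racy shape: capture the "previous" total up front and compute the delta at the end.
    // If another thread with the same task attempt id reserves memory in between, its
    // reservations get folded into this caller's delta.
    def racyPendingAmount(taskAttemptId: Long, reservations: Seq[Long]): Long = {
      val previous = unrollMemoryMap(taskAttemptId)
      reservations.foreach(bytes => unrollMemoryMap(taskAttemptId) += bytes)
      unrollMemoryMap(taskAttemptId) - previous
    }

    // Fixed shape: accumulate exactly what this call reserved in a local variable, so
    // concurrent callers sharing the same task attempt id cannot distort the result.
    def safePendingAmount(taskAttemptId: Long, reservations: Seq[Long]): Long = {
      var pendingMemoryReserved = 0L
      reservations.foreach { bytes =>
        unrollMemoryMap(taskAttemptId) += bytes
        pendingMemoryReserved += bytes
      }
      pendingMemoryReserved
    }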
--- .../org/apache/spark/storage/MemoryStore.scala | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 76aaa782b952..024b660ce6a7 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -255,8 +255,8 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo var memoryThreshold = initialMemoryThreshold // Memory to request as a multiple of current vector size val memoryGrowthFactor = 1.5 - // Previous unroll memory held by this task, for releasing later (only at the very end) - val previousMemoryReserved = currentUnrollMemoryForThisTask + // Keep track of pending unroll memory reserved by this method. + var pendingMemoryReserved = 0L // Underlying vector for unrolling the block var vector = new SizeTrackingVector[Any] @@ -266,6 +266,8 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo if (!keepUnrolling) { logWarning(s"Failed to reserve initial memory threshold of " + s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.") + } else { + pendingMemoryReserved += initialMemoryThreshold } // Unroll this block safely, checking whether we have exceeded our threshold periodically @@ -278,6 +280,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo if (currentSize >= memoryThreshold) { val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest) + if (keepUnrolling) { + pendingMemoryReserved += amountToRequest + } // New threshold is currentSize * memoryGrowthFactor memoryThreshold += amountToRequest } @@ -304,10 +309,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo // release the unroll memory yet. Instead, we transfer it to pending unroll memory // so `tryToPut` can further transfer it to normal storage memory later. // TODO: we can probably express this without pending unroll memory (SPARK-10907) - val amountToTransferToPending = currentUnrollMemoryForThisTask - previousMemoryReserved - unrollMemoryMap(taskAttemptId) -= amountToTransferToPending + unrollMemoryMap(taskAttemptId) -= pendingMemoryReserved pendingUnrollMemoryMap(taskAttemptId) = - pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + amountToTransferToPending + pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + pendingMemoryReserved } } else { // Otherwise, if we return an iterator, we can only release the unroll memory when From 99a6e3c1e8d580ce1cc497bd9362eaf16c597f77 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 2 Feb 2016 19:47:44 -0800 Subject: [PATCH 108/131] [SPARK-12951] [SQL] support spilling in generated aggregate This PR add spilling support for generated TungstenAggregate. If spilling happened, it's not that bad to do the iterator based sort-merge-aggregate (not generated). The changes will be covered by TungstenAggregationQueryWithControlledFallbackSuite Author: Davies Liu Closes #10998 from davies/gen_spilling. 
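A rough sketch of the control flow this adds to the generated aggregate, using plain Scala collections as stand-ins for UnsafeFixedWidthAggregationMap and UnsafeKVExternalSorter (the names and the maxKeysInMemory threshold below are simplifications, not Spark's API): hash-aggregate until the map is full, spill the partial aggregates as a sorted run, and if any spill happened, finish with an iterator-based sort-merge-aggregate over the runs.

import scala.collection.mutable

object SpillSketch {
  case class Agg(var sum: Long)

  def aggregateWithSpill(
      rows: Iterator[(String, Long)],
      maxKeysInMemory: Int): Iterator[(String, Long)] = {
    val hashMap = mutable.HashMap[String, Agg]()
    val spills = mutable.ArrayBuffer[Seq[(String, Long)]]()

    for ((key, value) <- rows) {
      if (!hashMap.contains(key) && hashMap.size >= maxKeysInMemory) {
        // "Spill": emit the current partial aggregates as a run sorted by key, free the map.
        spills += hashMap.iterator.map { case (k, a) => (k, a.sum) }.toSeq.sortBy(_._1)
        hashMap.clear()
      }
      hashMap.getOrElseUpdate(key, Agg(0L)).sum += value
    }

    if (spills.isEmpty) {
      // No spill: just return the hash map contents, as before this patch.
      hashMap.iterator.map { case (k, a) => (k, a.sum) }
    } else {
      // Spill happened: treat the remaining map as one more sorted run, then
      // sort-merge-aggregate by merging adjacent entries with equal keys.
      spills += hashMap.iterator.map { case (k, a) => (k, a.sum) }.toSeq.sortBy(_._1)
      val sorted = spills.flatten.sortBy(_._1)
      sorted.foldLeft(List.empty[(String, Long)]) {
        case ((k, s) :: rest, (key, v)) if k == key => (k, s + v) :: rest
        case (acc, (key, v)) => (key, v) :: acc
      }.reverse.iterator
    }
  }
}

For example, SpillSketch.aggregateWithSpill(Iterator(("a", 1L), ("b", 2L), ("a", 3L)), maxKeysInMemory = 1) forces a spill after the first key and still produces the same sums as the pure in-memory path, which is the property the controlled-fallback suite exercises.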
--- .../aggregate/TungstenAggregate.scala | 172 +++++++++++++++--- 1 file changed, 142 insertions(+), 30 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index a8a81d6d6574..f61db8594dab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -25,9 +25,9 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap} +import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.{DecimalType, StructType} +import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.KVIterator case class TungstenAggregate( @@ -258,6 +258,7 @@ case class TungstenAggregate( // The name for HashMap private var hashMapTerm: String = _ + private var sorterTerm: String = _ /** * This is called by generated Java class, should be public. @@ -286,39 +287,98 @@ case class TungstenAggregate( GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) } - /** - * Update peak execution memory, called in generated Java class. + * Called by generated Java class to finish the aggregate and return a KVIterator. */ - def updatePeakMemory(hashMap: UnsafeFixedWidthAggregationMap): Unit = { + def finishAggregate( + hashMap: UnsafeFixedWidthAggregationMap, + sorter: UnsafeKVExternalSorter): KVIterator[UnsafeRow, UnsafeRow] = { + + // update peak execution memory val mapMemory = hashMap.getPeakMemoryUsedBytes + val sorterMemory = Option(sorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L) + val peakMemory = Math.max(mapMemory, sorterMemory) val metrics = TaskContext.get().taskMetrics() - metrics.incPeakExecutionMemory(mapMemory) - } + metrics.incPeakExecutionMemory(peakMemory) - private def doProduceWithKeys(ctx: CodegenContext): String = { - val initAgg = ctx.freshName("initAgg") - ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + if (sorter == null) { + // not spilled + return hashMap.iterator() + } - // create hashMap - val thisPlan = ctx.addReferenceObj("plan", this) - hashMapTerm = ctx.freshName("hashMap") - val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName - ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();") + // merge the final hashMap into sorter + sorter.merge(hashMap.destructAndCreateExternalSorter()) + hashMap.free() + val sortedIter = sorter.sortedIterator() + + // Create a KVIterator based on the sorted iterator. 
+ new KVIterator[UnsafeRow, UnsafeRow] { + + // Create a MutableProjection to merge the rows of same key together + val mergeExpr = declFunctions.flatMap(_.mergeExpressions) + val mergeProjection = newMutableProjection( + mergeExpr, + bufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes), + subexpressionEliminationEnabled)() + val joinedRow = new JoinedRow() + + var currentKey: UnsafeRow = null + var currentRow: UnsafeRow = null + var nextKey: UnsafeRow = if (sortedIter.next()) { + sortedIter.getKey + } else { + null + } - // Create a name for iterator from HashMap - val iterTerm = ctx.freshName("mapIter") - ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "") + override def next(): Boolean = { + if (nextKey != null) { + currentKey = nextKey.copy() + currentRow = sortedIter.getValue.copy() + nextKey = null + // use the first row as aggregate buffer + mergeProjection.target(currentRow) + + // merge the following rows with same key together + var findNextGroup = false + while (!findNextGroup && sortedIter.next()) { + val key = sortedIter.getKey + if (currentKey.equals(key)) { + mergeProjection(joinedRow(currentRow, sortedIter.getValue)) + } else { + // We find a new group. + findNextGroup = true + nextKey = key + } + } + + true + } else { + false + } + } - // generate code for output - val keyTerm = ctx.freshName("aggKey") - val bufferTerm = ctx.freshName("aggBuffer") - val outputCode = if (modes.contains(Final) || modes.contains(Complete)) { + override def getKey: UnsafeRow = currentKey + override def getValue: UnsafeRow = currentRow + override def close(): Unit = { + sortedIter.close() + } + } + } + + /** + * Generate the code for output. + */ + private def generateResultCode( + ctx: CodegenContext, + keyTerm: String, + bufferTerm: String, + plan: String): String = { + if (modes.contains(Final) || modes.contains(Complete)) { // generate output using resultExpressions ctx.currentVars = null ctx.INPUT_ROW = keyTerm val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) => - BoundReference(i, e.dataType, e.nullable).gen(ctx) + BoundReference(i, e.dataType, e.nullable).gen(ctx) } ctx.INPUT_ROW = bufferTerm val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) => @@ -348,7 +408,7 @@ case class TungstenAggregate( // This should be the last operator in a stage, we should output UnsafeRow directly val joinerTerm = ctx.freshName("unsafeRowJoiner") ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm, - s"$joinerTerm = $thisPlan.createUnsafeJoiner();") + s"$joinerTerm = $plan.createUnsafeJoiner();") val resultRow = ctx.freshName("resultRow") s""" UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm); @@ -367,6 +427,23 @@ case class TungstenAggregate( ${consume(ctx, eval)} """ } + } + + private def doProduceWithKeys(ctx: CodegenContext): String = { + val initAgg = ctx.freshName("initAgg") + ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + + // create hashMap + val thisPlan = ctx.addReferenceObj("plan", this) + hashMapTerm = ctx.freshName("hashMap") + val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName + ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();") + sorterTerm = ctx.freshName("sorter") + ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "") + + // Create a name for iterator from HashMap + val iterTerm = ctx.freshName("mapIter") + ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, 
iterTerm, "") val doAgg = ctx.freshName("doAggregateWithKeys") ctx.addNewFunction(doAgg, @@ -374,10 +451,15 @@ case class TungstenAggregate( private void $doAgg() throws java.io.IOException { ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} - $iterTerm = $hashMapTerm.iterator(); + $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm); } """) + // generate code for output + val keyTerm = ctx.freshName("aggKey") + val bufferTerm = ctx.freshName("aggBuffer") + val outputCode = generateResultCode(ctx, keyTerm, bufferTerm, thisPlan) + s""" if (!$initAgg) { $initAgg = true; @@ -391,8 +473,10 @@ case class TungstenAggregate( $outputCode } - $thisPlan.updatePeakMemory($hashMapTerm); - $hashMapTerm.free(); + $iterTerm.close(); + if ($sorterTerm == null) { + $hashMapTerm.free(); + } """ } @@ -425,14 +509,42 @@ case class TungstenAggregate( ctx.updateColumn(buffer, dt, i, ev, updateExpr(i).nullable) } + val (checkFallback, resetCoulter, incCounter) = if (testFallbackStartsAt.isDefined) { + val countTerm = ctx.freshName("fallbackCounter") + ctx.addMutableState("int", countTerm, s"$countTerm = 0;") + (s"$countTerm < ${testFallbackStartsAt.get}", s"$countTerm = 0;", s"$countTerm += 1;") + } else { + ("true", "", "") + } + + // We try to do hash map based in-memory aggregation first. If there is not enough memory (the + // hash map will return null for new key), we spill the hash map to disk to free memory, then + // continue to do in-memory aggregation and spilling until all the rows had been processed. + // Finally, sort the spilled aggregate buffers by key, and merge them together for same key. s""" // generate grouping key ${keyCode.code} - UnsafeRow $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key); + UnsafeRow $buffer = null; + if ($checkFallback) { + // try to get the buffer from hash map + $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key); + } if ($buffer == null) { - // failed to allocate the first page - throw new OutOfMemoryError("No enough memory for aggregation"); + if ($sorterTerm == null) { + $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter(); + } else { + $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter()); + } + $resetCoulter + // the hash map had be spilled, it should have enough memory now, + // try to allocate buffer again. + $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key); + if ($buffer == null) { + // failed to allocate the first page + throw new OutOfMemoryError("No enough memory for aggregation"); + } } + $incCounter // evaluate aggregate function ${evals.map(_.code).mkString("\n")} From 0557146619868002e2f7ec3c121c30bbecc918fc Mon Sep 17 00:00:00 2001 From: Imran Younus Date: Tue, 2 Feb 2016 20:38:53 -0800 Subject: [PATCH 109/131] [SPARK-12732][ML] bug fix in linear regression train Fixed the bug in linear regression train for the case when the target variable is constant. The two cases for `fitIntercept=true` or `fitIntercept=false` should be treated differently. Author: Imran Younus Closes #10702 from iyounus/SPARK-12732_bug_fix_in_linear_regression_train. 
--- .../ml/regression/LinearRegression.scala | 66 ++++++----- .../ml/regression/LinearRegressionSuite.scala | 105 ++++++++++++++++++ 2 files changed, 146 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index c54e08b2ad9a..e253f25c0ea6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -219,33 +219,49 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String } val yMean = ySummarizer.mean(0) - val yStd = math.sqrt(ySummarizer.variance(0)) - - // If the yStd is zero, then the intercept is yMean with zero coefficient; - // as a result, training is not needed. - if (yStd == 0.0) { - logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + - s"zeros and the intercept will be the mean of the label; as a result, " + - s"training is not needed.") - if (handlePersistence) instances.unpersist() - val coefficients = Vectors.sparse(numFeatures, Seq()) - val intercept = yMean - - val model = new LinearRegressionModel(uid, coefficients, intercept) - // Handle possible missing or invalid prediction columns - val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() - - val trainingSummary = new LinearRegressionTrainingSummary( - summaryModel.transform(dataset), - predictionColName, - $(labelCol), - model, - Array(0D), - $(featuresCol), - Array(0D)) - return copyValues(model.setSummary(trainingSummary)) + val rawYStd = math.sqrt(ySummarizer.variance(0)) + if (rawYStd == 0.0) { + if ($(fitIntercept) || yMean==0.0) { + // If the rawYStd is zero and fitIntercept=true, then the intercept is yMean with + // zero coefficient; as a result, training is not needed. + // Also, if yMean==0 and rawYStd==0, all the coefficients are zero regardless of + // the fitIntercept + if (yMean == 0.0) { + logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " + + s"and the intercept will all be zero; as a result, training is not needed.") + } else { + logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + + s"zeros and the intercept will be the mean of the label; as a result, " + + s"training is not needed.") + } + if (handlePersistence) instances.unpersist() + val coefficients = Vectors.sparse(numFeatures, Seq()) + val intercept = yMean + + val model = new LinearRegressionModel(uid, coefficients, intercept) + // Handle possible missing or invalid prediction columns + val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() + + val trainingSummary = new LinearRegressionTrainingSummary( + summaryModel.transform(dataset), + predictionColName, + $(labelCol), + model, + Array(0D), + $(featuresCol), + Array(0D)) + return copyValues(model.setSummary(trainingSummary)) + } else { + require($(regParam) == 0.0, "The standard deviation of the label is zero. " + + "Model cannot be regularized.") + logWarning(s"The standard deviation of the label is zero. " + + "Consider setting fitIntercept=true.") + } } + // if y is constant (rawYStd is zero), then y cannot be scaled. In this case + // setting yStd=1.0 ensures that y is not scaled anymore in l-bfgs algorithm. 
+ val yStd = if (rawYStd > 0) rawYStd else math.abs(yMean) val featuresMean = featuresSummarizer.mean.toArray val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 273c882c2a47..81fc6603ccfe 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -37,6 +37,8 @@ class LinearRegressionSuite @transient var datasetWithDenseFeatureWithoutIntercept: DataFrame = _ @transient var datasetWithSparseFeature: DataFrame = _ @transient var datasetWithWeight: DataFrame = _ + @transient var datasetWithWeightConstantLabel: DataFrame = _ + @transient var datasetWithWeightZeroLabel: DataFrame = _ /* In `LinearRegressionSuite`, we will make sure that the model trained by SparkML @@ -92,6 +94,29 @@ class LinearRegressionSuite Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), Instance(29.0, 4.0, Vectors.dense(3.0, 13.0)) ), 2)) + + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b.const <- c(17, 17, 17, 17) + w <- c(1, 2, 3, 4) + df.const.label <- as.data.frame(cbind(A, b.const)) + */ + datasetWithWeightConstantLabel = sqlContext.createDataFrame( + sc.parallelize(Seq( + Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(17.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(17.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2)) + datasetWithWeightZeroLabel = sqlContext.createDataFrame( + sc.parallelize(Seq( + Instance(0.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(0.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(0.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(0.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2)) } test("params") { @@ -558,6 +583,86 @@ class LinearRegressionSuite } } + test("linear regression model with constant label") { + /* + R code: + for (formula in c(b.const ~ . -1, b.const ~ .)) { + model <- lm(formula, data=df.const.label, weights=w) + print(as.vector(coef(model))) + } + [1] -9.221298 3.394343 + [1] 17 0 0 + */ + val expected = Seq( + Vectors.dense(0.0, -9.221298, 3.394343), + Vectors.dense(17.0, 0.0, 0.0)) + + Seq("auto", "l-bfgs", "normal").foreach { solver => + var idx = 0 + for (fitIntercept <- Seq(false, true)) { + val model1 = new LinearRegression() + .setFitIntercept(fitIntercept) + .setWeightCol("weight") + .setSolver(solver) + .fit(datasetWithWeightConstantLabel) + val actual1 = Vectors.dense(model1.intercept, model1.coefficients(0), + model1.coefficients(1)) + assert(actual1 ~== expected(idx) absTol 1e-4) + + val model2 = new LinearRegression() + .setFitIntercept(fitIntercept) + .setWeightCol("weight") + .setSolver(solver) + .fit(datasetWithWeightZeroLabel) + val actual2 = Vectors.dense(model2.intercept, model2.coefficients(0), + model2.coefficients(1)) + assert(actual2 ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1e-4) + idx += 1 + } + } + } + + test("regularized linear regression through origin with constant label") { + // The problem is ill-defined if fitIntercept=false, regParam is non-zero. + // An exception is thrown in this case. 
+ Seq("auto", "l-bfgs", "normal").foreach { solver => + for (standardization <- Seq(false, true)) { + val model = new LinearRegression().setFitIntercept(false) + .setRegParam(0.1).setStandardization(standardization).setSolver(solver) + intercept[IllegalArgumentException] { + model.fit(datasetWithWeightConstantLabel) + } + } + } + } + + test("linear regression with l-bfgs when training is not needed") { + // When label is constant, l-bfgs solver returns results without training. + // There are two possibilities: If the label is non-zero but constant, + // and fitIntercept is true, then the model return yMean as intercept without training. + // If label is all zeros, then all coefficients are zero regardless of fitIntercept, so + // no training is needed. + for (fitIntercept <- Seq(false, true)) { + for (standardization <- Seq(false, true)) { + val model1 = new LinearRegression() + .setFitIntercept(fitIntercept) + .setStandardization(standardization) + .setWeightCol("weight") + .setSolver("l-bfgs") + .fit(datasetWithWeightConstantLabel) + if (fitIntercept) { + assert(model1.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4) + } + val model2 = new LinearRegression() + .setFitIntercept(fitIntercept) + .setWeightCol("weight") + .setSolver("l-bfgs") + .fit(datasetWithWeightZeroLabel) + assert(model2.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4) + } + } + } + test("linear regression model training summary") { Seq("auto", "l-bfgs", "normal").foreach { solver => val trainer = new LinearRegression().setSolver(solver) From 335f10edad8c759bad3dbd0660ed4dd5d70ddd8b Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 2 Feb 2016 21:13:54 -0800 Subject: [PATCH 110/131] [SPARK-7997][CORE] Add rpcEnv.awaitTermination() back to SparkEnv `rpcEnv.awaitTermination()` was not added in #10854 because some Streaming Python tests hung forever. This patch fixed the hung issue and added rpcEnv.awaitTermination() back to SparkEnv. Previously, Streaming Kafka Python tests shutdowns the zookeeper server before stopping StreamingContext. Then when stopping StreamingContext, KafkaReceiver may be hung due to https://issues.apache.org/jira/browse/KAFKA-601, hence, some thread of RpcEnv's Dispatcher cannot exit and rpcEnv.awaitTermination is hung.The patch just changed the shutdown order to fix it. Author: Shixiong Zhu Closes #11031 from zsxwing/awaitTermination. --- core/src/main/scala/org/apache/spark/SparkEnv.scala | 1 + python/pyspark/streaming/tests.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 12c7b2048a8c..9461afdc5412 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -91,6 +91,7 @@ class SparkEnv ( metricsSystem.stop() outputCommitCoordinator.stop() rpcEnv.shutdown() + rpcEnv.awaitTermination() // Note that blockTransferService is stopped by BlockManager since it is started by it. 
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 24b812615cbb..b33e8252a7d3 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1013,12 +1013,12 @@ def setUp(self): self._kafkaTestUtils.setup() def tearDown(self): + super(KafkaStreamTests, self).tearDown() + if self._kafkaTestUtils is not None: self._kafkaTestUtils.teardown() self._kafkaTestUtils = None - super(KafkaStreamTests, self).tearDown() - def _randomTopic(self): return "topic-%d" % random.randint(0, 10000) From e86f8f63bfa3c15659b94e831b853b1bc9ddae32 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 2 Feb 2016 22:13:10 -0800 Subject: [PATCH 111/131] [SPARK-13147] [SQL] improve readability of generated code 1. try to avoid the suffix (unique id) 2. remove the comment if there is no code generated. 3. re-arrange the order of functions 4. trop the new line for inlined blocks. Author: Davies Liu Closes #11032 from davies/better_suffix. --- .../sql/catalyst/expressions/Expression.scala | 8 +++-- .../expressions/codegen/CodeGenerator.scala | 27 ++++++++++------ .../expressions/complexTypeExtractors.scala | 31 +++++++++++-------- .../sql/execution/WholeStageCodegen.scala | 13 ++++---- .../aggregate/TungstenAggregate.scala | 14 ++++----- .../spark/sql/execution/basicOperators.scala | 7 ++++- .../BenchmarkWholeStageCodegen.scala | 2 +- 7 files changed, 63 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 353fb92581d3..c73b2f8f2a31 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -103,8 +103,12 @@ abstract class Expression extends TreeNode[Expression] { val value = ctx.freshName("value") val ve = ExprCode("", isNull, value) ve.code = genCode(ctx, ve) - // Add `this` in the comment. - ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim) + if (ve.code != "") { + // Add `this` in the comment. + ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim) + } else { + ve + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index a30aba16170a..63e19564dd86 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -156,7 +156,11 @@ class CodegenContext { /** The variable name of the input row in generated code. */ final var INPUT_ROW = "i" - private val curId = new java.util.concurrent.atomic.AtomicInteger() + /** + * The map from a variable name to it's next ID. + */ + private val freshNameIds = new mutable.HashMap[String, Int] + freshNameIds += INPUT_ROW -> 1 /** * A prefix used to generate fresh name. @@ -164,16 +168,21 @@ class CodegenContext { var freshNamePrefix = "" /** - * Returns a term name that is unique within this instance of a `CodeGenerator`. - * - * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` - * function.) + * Returns a term name that is unique within this instance of a `CodegenContext`. 
*/ - def freshName(name: String): String = { - if (freshNamePrefix == "") { - s"$name${curId.getAndIncrement}" + def freshName(name: String): String = synchronized { + val fullName = if (freshNamePrefix == "") { + name + } else { + s"${freshNamePrefix}_$name" + } + if (freshNameIds.contains(fullName)) { + val id = freshNameIds(fullName) + freshNameIds(fullName) = id + 1 + s"$fullName$id" } else { - s"${freshNamePrefix}_$name${curId.getAndIncrement}" + freshNameIds += fullName -> 1 + fullName } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 9f2f82d68cca..6b24fae9f3f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -173,22 +173,26 @@ case class GetArrayStructFields( override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val arrayClass = classOf[GenericArrayData].getName nullSafeCodeGen(ctx, ev, eval => { + val n = ctx.freshName("n") + val values = ctx.freshName("values") + val j = ctx.freshName("j") + val row = ctx.freshName("row") s""" - final int n = $eval.numElements(); - final Object[] values = new Object[n]; - for (int j = 0; j < n; j++) { - if ($eval.isNullAt(j)) { - values[j] = null; + final int $n = $eval.numElements(); + final Object[] $values = new Object[$n]; + for (int $j = 0; $j < $n; $j++) { + if ($eval.isNullAt($j)) { + $values[$j] = null; } else { - final InternalRow row = $eval.getStruct(j, $numFields); - if (row.isNullAt($ordinal)) { - values[j] = null; + final InternalRow $row = $eval.getStruct($j, $numFields); + if ($row.isNullAt($ordinal)) { + $values[$j] = null; } else { - values[j] = ${ctx.getValue("row", field.dataType, ordinal.toString)}; + $values[$j] = ${ctx.getValue(row, field.dataType, ordinal.toString)}; } } } - ${ev.value} = new $arrayClass(values); + ${ev.value} = new $arrayClass($values); """ }) } @@ -227,12 +231,13 @@ case class GetArrayItem(child: Expression, ordinal: Expression) override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + val index = ctx.freshName("index") s""" - final int index = (int) $eval2; - if (index >= $eval1.numElements() || index < 0 || $eval1.isNullAt(index)) { + final int $index = (int) $eval2; + if ($index >= $eval1.numElements() || $index < 0 || $eval1.isNullAt($index)) { ${ev.isNull} = true; } else { - ${ev.value} = ${ctx.getValue(eval1, dataType, "index")}; + ${ev.value} = ${ctx.getValue(eval1, dataType, index)}; } """ }) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 02b0f423ed43..14754969072f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -170,8 +170,8 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { s""" | while (input.hasNext()) { | InternalRow $row = (InternalRow) input.next(); - | ${columns.map(_.code).mkString("\n")} - | ${consume(ctx, columns)} + | ${columns.map(_.code).mkString("\n").trim} + | ${consume(ctx, columns).trim} | } """.stripMargin } @@ -236,15 +236,16 @@ case class WholeStageCodegen(plan: CodegenSupport, 
children: Seq[SparkPlan]) private Object[] references; ${ctx.declareMutableStates()} - ${ctx.declareAddedFunctions()} public GeneratedIterator(Object[] references) { - this.references = references; - ${ctx.initMutableStates()} + this.references = references; + ${ctx.initMutableStates()} } + ${ctx.declareAddedFunctions()} + protected void processNext() throws java.io.IOException { - $code + ${code.trim} } } """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index f61db8594dab..d0244770613d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -211,9 +211,9 @@ case class TungstenAggregate( | $doAgg(); | | // output the result - | $genResult + | ${genResult.trim} | - | ${consume(ctx, resultVars)} + | ${consume(ctx, resultVars).trim} | } """.stripMargin } @@ -242,9 +242,9 @@ case class TungstenAggregate( } s""" | // do aggregate - | ${aggVals.map(_.code).mkString("\n")} + | ${aggVals.map(_.code).mkString("\n").trim} | // update aggregation buffer - | ${updates.mkString("")} + | ${updates.mkString("\n").trim} """.stripMargin } @@ -523,7 +523,7 @@ case class TungstenAggregate( // Finally, sort the spilled aggregate buffers by key, and merge them together for same key. s""" // generate grouping key - ${keyCode.code} + ${keyCode.code.trim} UnsafeRow $buffer = null; if ($checkFallback) { // try to get the buffer from hash map @@ -547,9 +547,9 @@ case class TungstenAggregate( $incCounter // evaluate aggregate function - ${evals.map(_.code).mkString("\n")} + ${evals.map(_.code).mkString("\n").trim} // update aggregate buffer - ${updates.mkString("\n")} + ${updates.mkString("\n").trim} """ } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index fd81531c9316..ae4422195cc4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -93,9 +93,14 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit BindReferences.bindReference(condition, child.output)) ctx.currentVars = input val eval = expr.gen(ctx) + val nullCheck = if (expr.nullable) { + s"!${eval.isNull} &&" + } else { + s"" + } s""" | ${eval.code} - | if (!${eval.isNull} && ${eval.value}) { + | if ($nullCheck ${eval.value}) { | ${consume(ctx, ctx.currentVars)} | } """.stripMargin diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 1ccf0e3d0656..ec2b9ab2cbad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -199,7 +199,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { // These benchmark are skipped in normal build ignore("benchmark") { // testWholeStage(200 << 20) - // testStddev(20 << 20) + // testStatFunctions(20 << 20) // testAggregateWithKey(20 << 20) // testBytesToBytesMap(1024 * 1024 * 50) } From 138c300f97d29cb0d04a70bea98a8a0c0548318a Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Tue, 2 Feb 
2016 22:22:50 -0800 Subject: [PATCH 112/131] [SPARK-12957][SQL] Initial support for constraint propagation in SparkSQL MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Based on the semantics of a query, we can derive a number of data constraints on output of each (logical or physical) operator. For instance, if a filter defines `‘a > 10`, we know that the output data of this filter satisfies 2 constraints: 1. `‘a > 10` 2. `isNotNull(‘a)` This PR proposes a possible way of keeping track of these constraints and propagating them in the logical plan, which can then help us build more advanced optimizations (such as pruning redundant filters, optimizing joins, among others). We define constraints as a set of (implicitly conjunctive) expressions. For e.g., if a filter operator has constraints = `Set(‘a > 10, ‘b < 100)`, it’s implied that the outputs satisfy both individual constraints (i.e., `‘a > 10` AND `‘b < 100`). Design Document: https://docs.google.com/a/databricks.com/document/d/1WQRgDurUBV9Y6CWOBS75PQIqJwT-6WftVa18xzm7nCo/edit?usp=sharing Author: Sameer Agarwal Closes #10844 from sameeragarwal/constraints. --- .../spark/sql/catalyst/plans/QueryPlan.scala | 55 +++++- .../catalyst/plans/logical/LogicalPlan.scala | 2 + .../plans/logical/basicOperators.scala | 79 +++++++- .../plans/ConstraintPropagationSuite.scala | 173 ++++++++++++++++++ 4 files changed, 302 insertions(+), 7 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index b43b7ee71e7a..05f5bdbfc076 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, VirtualColumn} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{DataType, StructType} @@ -26,6 +26,56 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy def output: Seq[Attribute] + /** + * Extracts the relevant constraints from a given set of constraints based on the attributes that + * appear in the [[outputSet]]. + */ + protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = { + constraints + .union(constructIsNotNullConstraints(constraints)) + .filter(constraint => + constraint.references.nonEmpty && constraint.references.subsetOf(outputSet)) + } + + /** + * Infers a set of `isNotNull` constraints from a given set of equality/comparison expressions. + * For e.g., if an expression is of the form (`a > 5`), this returns a constraint of the form + * `isNotNull(a)` + */ + private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = { + // Currently we only propagate constraints if the condition consists of equality + // and ranges. 
For all other cases, we return an empty set of constraints + constraints.map { + case EqualTo(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case GreaterThan(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case GreaterThanOrEqual(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case LessThan(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case LessThanOrEqual(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case _ => + Set.empty[Expression] + }.foldLeft(Set.empty[Expression])(_ union _.toSet) + } + + /** + * A sequence of expressions that describes the data property of the output rows of this + * operator. For example, if the output of this operator is column `a`, an example `constraints` + * can be `Set(a > 10, a < 20)`. + */ + lazy val constraints: Set[Expression] = getRelevantConstraints(validConstraints) + + /** + * This method can be overridden by any child class of QueryPlan to specify a set of constraints + * based on the given operator's constraint propagation logic. These constraints are then + * canonicalized and filtered automatically to contain only those attributes that appear in the + * [[outputSet]] + */ + protected def validConstraints: Set[Expression] = Set.empty + /** * Returns the set of attributes that are output by this node. */ @@ -59,6 +109,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy * Runs [[transform]] with `rule` on all expressions present in this query operator. * Users should not expect a specific directionality. If a specific directionality is needed, * transformExpressionsDown or transformExpressionsUp should be used. + * * @param rule the rule to be applied to every expression in this operator. */ def transformExpressions(rule: PartialFunction[Expression, Expression]): this.type = { @@ -67,6 +118,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy /** * Runs [[transformDown]] with `rule` on all expressions present in this query operator. + * * @param rule the rule to be applied to every expression in this operator. */ def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = { @@ -99,6 +151,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy /** * Runs [[transformUp]] with `rule` on all expressions present in this query operator. + * * @param rule the rule to be applied to every expression in this operator. 
* @return */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 6d859551f8c5..d8944a424156 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -305,6 +305,8 @@ abstract class UnaryNode extends LogicalPlan { def child: LogicalPlan override def children: Seq[LogicalPlan] = child :: Nil + + override protected def validConstraints: Set[Expression] = child.constraints } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 16f4b355b1b6..8150ff843476 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -87,11 +87,27 @@ case class Generate( } } -case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { +case class Filter(condition: Expression, child: LogicalPlan) + extends UnaryNode with PredicateHelper { override def output: Seq[Attribute] = child.output + + override protected def validConstraints: Set[Expression] = { + child.constraints.union(splitConjunctivePredicates(condition).toSet) + } } -abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode +abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { + + protected def leftConstraints: Set[Expression] = left.constraints + + protected def rightConstraints: Set[Expression] = { + require(left.output.size == right.output.size) + val attributeRewrites = AttributeMap(right.output.zip(left.output)) + right.constraints.map(_ transform { + case a: Attribute => attributeRewrites(a) + }) + } +} private[sql] object SetOperation { def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right)) @@ -106,6 +122,10 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable) } + override protected def validConstraints: Set[Expression] = { + leftConstraints.union(rightConstraints) + } + // Intersect are only resolved if they don't introduce ambiguous expression ids, // since the Optimizer will convert Intersect to Join. override lazy val resolved: Boolean = @@ -119,6 +139,8 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le /** We don't use right.output because those rows get excluded from the set. */ override def output: Seq[Attribute] = left.output + override protected def validConstraints: Set[Expression] = leftConstraints + override lazy val resolved: Boolean = childrenResolved && left.output.length == right.output.length && @@ -157,13 +179,36 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { val sizeInBytes = children.map(_.statistics.sizeInBytes).sum Statistics(sizeInBytes = sizeInBytes) } + + /** + * Maps the constraints containing a given (original) sequence of attributes to those with a + * given (reference) sequence of attributes. Given the nature of union, we expect that the + * mapping between the original and reference sequences are symmetric. 
+ */ + private def rewriteConstraints( + reference: Seq[Attribute], + original: Seq[Attribute], + constraints: Set[Expression]): Set[Expression] = { + require(reference.size == original.size) + val attributeRewrites = AttributeMap(original.zip(reference)) + constraints.map(_ transform { + case a: Attribute => attributeRewrites(a) + }) + } + + override protected def validConstraints: Set[Expression] = { + children + .map(child => rewriteConstraints(children.head.output, child.output, child.constraints)) + .reduce(_ intersect _) + } } case class Join( - left: LogicalPlan, - right: LogicalPlan, - joinType: JoinType, - condition: Option[Expression]) extends BinaryNode { + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + condition: Option[Expression]) + extends BinaryNode with PredicateHelper { override def output: Seq[Attribute] = { joinType match { @@ -180,6 +225,28 @@ case class Join( } } + override protected def validConstraints: Set[Expression] = { + joinType match { + case Inner if condition.isDefined => + left.constraints + .union(right.constraints) + .union(splitConjunctivePredicates(condition.get).toSet) + case LeftSemi if condition.isDefined => + left.constraints + .union(splitConjunctivePredicates(condition.get).toSet) + case Inner => + left.constraints.union(right.constraints) + case LeftSemi => + left.constraints + case LeftOuter => + left.constraints + case RightOuter => + right.constraints + case FullOuter => + Set.empty[Expression] + } + } + def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty // Joins are only resolved if they don't introduce ambiguous expression ids. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala new file mode 100644 index 000000000000..b5cf91394d91 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ + +class ConstraintPropagationSuite extends SparkFunSuite { + + private def resolveColumn(tr: LocalRelation, columnName: String): Expression = + tr.analyze.resolveQuoted(columnName, caseInsensitiveResolution).get + + private def verifyConstraints(found: Set[Expression], expected: Set[Expression]): Unit = { + val missing = expected.filterNot(i => found.map(_.semanticEquals(i)).reduce(_ || _)) + val extra = found.filterNot(i => expected.map(_.semanticEquals(i)).reduce(_ || _)) + if (missing.nonEmpty || extra.nonEmpty) { + fail( + s""" + |== FAIL: Constraints do not match === + |Found: ${found.mkString(",")} + |Expected: ${expected.mkString(",")} + |== Result == + |Missing: ${if (missing.isEmpty) "N/A" else missing.mkString(",")} + |Found but not expected: ${if (extra.isEmpty) "N/A" else extra.mkString(",")} + """.stripMargin) + } + } + + test("propagating constraints in filters") { + val tr = LocalRelation('a.int, 'b.string, 'c.int) + + assert(tr.analyze.constraints.isEmpty) + + assert(tr.where('a.attr > 10).select('c.attr, 'b.attr).analyze.constraints.isEmpty) + + verifyConstraints(tr + .where('a.attr > 10) + .analyze.constraints, + Set(resolveColumn(tr, "a") > 10, + IsNotNull(resolveColumn(tr, "a")))) + + verifyConstraints(tr + .where('a.attr > 10) + .select('c.attr, 'a.attr) + .where('c.attr < 100) + .analyze.constraints, + Set(resolveColumn(tr, "a") > 10, + resolveColumn(tr, "c") < 100, + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "c")))) + } + + test("propagating constraints in union") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int) + val tr2 = LocalRelation('d.int, 'e.int, 'f.int) + val tr3 = LocalRelation('g.int, 'h.int, 'i.int) + + assert(tr1 + .where('a.attr > 10) + .unionAll(tr2.where('e.attr > 10) + .unionAll(tr3.where('i.attr > 10))) + .analyze.constraints.isEmpty) + + verifyConstraints(tr1 + .where('a.attr > 10) + .unionAll(tr2.where('d.attr > 10) + .unionAll(tr3.where('g.attr > 10))) + .analyze.constraints, + Set(resolveColumn(tr1, "a") > 10, + IsNotNull(resolveColumn(tr1, "a")))) + } + + test("propagating constraints in intersect") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int) + val tr2 = LocalRelation('a.int, 'b.int, 'c.int) + + verifyConstraints(tr1 + .where('a.attr > 10) + .intersect(tr2.where('b.attr < 100)) + .analyze.constraints, + Set(resolveColumn(tr1, "a") > 10, + resolveColumn(tr1, "b") < 100, + IsNotNull(resolveColumn(tr1, "a")), + IsNotNull(resolveColumn(tr1, "b")))) + } + + test("propagating constraints in except") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int) + val tr2 = LocalRelation('a.int, 'b.int, 'c.int) + verifyConstraints(tr1 + .where('a.attr > 10) + .except(tr2.where('b.attr < 100)) + .analyze.constraints, + Set(resolveColumn(tr1, "a") > 10, + IsNotNull(resolveColumn(tr1, "a")))) + } + + test("propagating constraints in inner join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + verifyConstraints(tr1 + .where('a.attr > 10) + .join(tr2.where('d.attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints, + Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 
10, + tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100, + tr1.resolveQuoted("a", caseInsensitiveResolution).get === + tr2.resolveQuoted("a", caseInsensitiveResolution).get, + IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get), + IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get), + IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get))) + } + + test("propagating constraints in left-semi join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + verifyConstraints(tr1 + .where('a.attr > 10) + .join(tr2.where('d.attr < 100), LeftSemi, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints, + Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, + IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))) + } + + test("propagating constraints in left-outer join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + verifyConstraints(tr1 + .where('a.attr > 10) + .join(tr2.where('d.attr < 100), LeftOuter, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints, + Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, + IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))) + } + + test("propagating constraints in right-outer join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + verifyConstraints(tr1 + .where('a.attr > 10) + .join(tr2.where('d.attr < 100), RightOuter, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints, + Set(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100, + IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get))) + } + + test("propagating constraints in full-outer join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + assert(tr1.where('a.attr > 10) + .join(tr2.where('d.attr < 100), FullOuter, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints.isEmpty) + } +} From e9eb248edfa81d75f99c9afc2063e6b3d9ee7392 Mon Sep 17 00:00:00 2001 From: Mario Briggs Date: Wed, 3 Feb 2016 09:50:28 -0800 Subject: [PATCH 113/131] [SPARK-12739][STREAMING] Details of batch in Streaming tab uses two Duration columns I have clearly prefix the two 'Duration' columns in 'Details of Batch' Streaming tab as 'Output Op Duration' and 'Job Duration' Author: Mario Briggs Author: mariobriggs Closes #11022 from mariobriggs/spark-12739. 
--- .../main/scala/org/apache/spark/streaming/ui/BatchPage.scala | 4 ++-- .../scala/org/apache/spark/streaming/UISeleniumSuite.scala | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index 7635f79a3d2d..81de07f933f8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -37,10 +37,10 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { private def columns: Seq[Node] = { Output Op Id Description - Duration + Output Op Duration Status Job Id - Duration + Job Duration Stages: Succeeded/Total Tasks (for all stages): Succeeded/Total Error diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index c4ecebcacf3c..96dd4757be85 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -143,8 +143,9 @@ class UISeleniumSuite summaryText should contain ("Total delay:") findAll(cssSelector("""#batch-job-table th""")).map(_.text).toSeq should be { - List("Output Op Id", "Description", "Duration", "Status", "Job Id", "Duration", - "Stages: Succeeded/Total", "Tasks (for all stages): Succeeded/Total", "Error") + List("Output Op Id", "Description", "Output Op Duration", "Status", "Job Id", + "Job Duration", "Stages: Succeeded/Total", "Tasks (for all stages): Succeeded/Total", + "Error") } // Check we have 2 output op ids From c4feec26eb677bfd3bfac38e5e28eae05279956e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 3 Feb 2016 10:38:53 -0800 Subject: [PATCH 114/131] [SPARK-12798] [SQL] generated BroadcastHashJoin A row from stream side could match multiple rows on build side, the loop for these matched rows should not be interrupted when emitting a row, so we buffer the output rows in a linked list, check the termination condition on producer loop (for example, Range or Aggregate). Author: Davies Liu Closes #10989 from davies/gen_join. --- .../sql/execution/BufferedRowIterator.java | 30 ++++-- .../sql/execution/WholeStageCodegen.scala | 18 ++-- .../aggregate/TungstenAggregate.scala | 4 +- .../spark/sql/execution/basicOperators.scala | 2 + .../execution/joins/BroadcastHashJoin.scala | 92 ++++++++++++++++++- .../BenchmarkWholeStageCodegen.scala | 28 +++++- .../execution/WholeStageCodegenSuite.scala | 15 ++- 7 files changed, 169 insertions(+), 20 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java index 6acf70dbbad0..ea20115770f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -18,9 +18,11 @@ package org.apache.spark.sql.execution; import java.io.IOException; +import java.util.LinkedList; import scala.collection.Iterator; +import org.apache.spark.TaskContext; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; @@ -31,28 +33,42 @@ * TODO: replaced it by batched columnar format. 
*/ public class BufferedRowIterator { - protected InternalRow currentRow; + protected LinkedList currentRows = new LinkedList<>(); protected Iterator input; // used when there is no column in output protected UnsafeRow unsafeRow = new UnsafeRow(0); public boolean hasNext() throws IOException { - if (currentRow == null) { + if (currentRows.isEmpty()) { processNext(); } - return currentRow != null; + return !currentRows.isEmpty(); } public InternalRow next() { - InternalRow r = currentRow; - currentRow = null; - return r; + return currentRows.remove(); } public void setInput(Iterator iter) { input = iter; } + /** + * Returns whether `processNext()` should stop processing next row from `input` or not. + * + * If it returns true, the caller should exit the loop (return from processNext()). + */ + protected boolean shouldStop() { + return !currentRows.isEmpty(); + } + + /** + * Increase the peak execution memory for current task. + */ + protected void incPeakExecutionMemory(long size) { + TaskContext.get().taskMetrics().incPeakExecutionMemory(size); + } + /** * Processes the input until have a row as output (currentRow). * @@ -60,7 +76,7 @@ public void setInput(Iterator iter) { */ protected void processNext() throws IOException { if (input.hasNext()) { - currentRow = input.next(); + currentRows.add(input.next()); } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 14754969072f..131efea20f31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.aggregate.TungstenAggregate +import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight} import org.apache.spark.util.Utils /** @@ -172,6 +173,9 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { | InternalRow $row = (InternalRow) input.next(); | ${columns.map(_.code).mkString("\n").trim} | ${consume(ctx, columns).trim} + | if (shouldStop()) { + | return; + | } | } """.stripMargin } @@ -283,8 +287,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) if (row != null) { // There is an UnsafeRow already s""" - | currentRow = $row; - | return; + | currentRows.add($row.copy()); """.stripMargin } else { assert(input != null) @@ -297,14 +300,12 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) s""" | ${code.code.trim} - | currentRow = ${code.value}; - | return; + | currentRows.add(${code.value}.copy()); """.stripMargin } else { // There is no columns s""" - | currentRow = unsafeRow; - | return; + | currentRows.add(unsafeRow); """.stripMargin } } @@ -371,6 +372,11 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru var inputs = ArrayBuffer[SparkPlan]() val combined = plan.transform { + // The build side can't be compiled together + case b @ BroadcastHashJoin(_, _, BuildLeft, _, left, right) => + b.copy(left = apply(left)) + case b @ BroadcastHashJoin(_, _, BuildRight, _, left, right) => + b.copy(right = apply(right)) case p if !supportCodegen(p) => val input = apply(p) // 
collapse them recursively inputs += input diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index d0244770613d..9d9f14f2dd01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -471,6 +471,8 @@ case class TungstenAggregate( UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey(); UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue(); $outputCode + + if (shouldStop()) return; } $iterTerm.close(); @@ -480,7 +482,7 @@ case class TungstenAggregate( """ } - private def doConsumeWithKeys( ctx: CodegenContext, input: Seq[ExprCode]): String = { + private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { // create grouping key ctx.currentVars = input diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index ae4422195cc4..6e51c4d84824 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -237,6 +237,8 @@ case class Range( | $overflow = true; | } | ${consume(ctx, Seq(ev))} + | + | if (shouldStop()) return; | } """.stripMargin } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 04640711d99d..8b275e886c46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -20,14 +20,17 @@ package org.apache.spark.sql.execution.joins import scala.concurrent._ import scala.concurrent.duration._ -import org.apache.spark.{InternalAccumulator, TaskContext} +import org.apache.spark.TaskContext +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{BindReferences, BoundReference, Expression, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan, SQLExecution} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.collection.CompactBuffer /** * Performs an inner hash join of two child relations. 
When the output RDD of this operator is @@ -42,7 +45,7 @@ case class BroadcastHashJoin( condition: Option[Expression], left: SparkPlan, right: SparkPlan) - extends BinaryNode with HashJoin { + extends BinaryNode with HashJoin with CodegenSupport { override private[sql] lazy val metrics = Map( "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), @@ -117,6 +120,87 @@ case class BroadcastHashJoin( hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows) } } + + // the term for hash relation + private var relationTerm: String = _ + + override def upstream(): RDD[InternalRow] = { + streamedPlan.asInstanceOf[CodegenSupport].upstream() + } + + override def doProduce(ctx: CodegenContext): String = { + // create a name for HashRelation + val broadcastRelation = Await.result(broadcastFuture, timeout) + val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation) + relationTerm = ctx.freshName("relation") + // TODO: create specialized HashRelation for single join key + val clsName = classOf[UnsafeHashedRelation].getName + ctx.addMutableState(clsName, relationTerm, + s""" + | $relationTerm = ($clsName) $broadcast.value(); + | incPeakExecutionMemory($relationTerm.getUnsafeSize()); + """.stripMargin) + + s""" + | ${streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)} + """.stripMargin + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + // generate the key as UnsafeRow + ctx.currentVars = input + val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output)) + val keyVal = GenerateUnsafeProjection.createCode(ctx, keyExpr) + val keyTerm = keyVal.value + val anyNull = if (keyExpr.exists(_.nullable)) s"$keyTerm.anyNull()" else "false" + + // find the matches from HashedRelation + val matches = ctx.freshName("matches") + val bufferType = classOf[CompactBuffer[UnsafeRow]].getName + val i = ctx.freshName("i") + val size = ctx.freshName("size") + val row = ctx.freshName("row") + + // create variables for output + ctx.currentVars = null + ctx.INPUT_ROW = row + val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) => + BoundReference(i, a.dataType, a.nullable).gen(ctx) + } + val resultVars = buildSide match { + case BuildLeft => buildColumns ++ input + case BuildRight => input ++ buildColumns + } + + val ouputCode = if (condition.isDefined) { + // filter the output via condition + ctx.currentVars = resultVars + val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx) + s""" + | ${ev.code} + | if (!${ev.isNull} && ${ev.value}) { + | ${consume(ctx, resultVars)} + | } + """.stripMargin + } else { + consume(ctx, resultVars) + } + + s""" + | // generate join key + | ${keyVal.code} + | // find matches from HashRelation + | $bufferType $matches = $anyNull ? 
null : ($bufferType) $relationTerm.get($keyTerm); + | if ($matches != null) { + | int $size = $matches.size(); + | for (int $i = 0; $i < $size; $i++) { + | UnsafeRow $row = (UnsafeRow) $matches.apply($i); + | ${buildColumns.map(_.code).mkString("\n")} + | $ouputCode + | } + | } + """.stripMargin + } } object BroadcastHashJoin { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index ec2b9ab2cbad..15ba77353109 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -21,6 +21,7 @@ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.functions._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.hash.Murmur3_x86_32 import org.apache.spark.unsafe.map.BytesToBytesMap @@ -130,6 +131,30 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { benchmark.run() } + def testBroadcastHashJoin(values: Int): Unit = { + val benchmark = new Benchmark("BroadcastHashJoin", values) + + val dim = broadcast(sqlContext.range(1 << 16).selectExpr("id as k", "cast(id as string) as v")) + + benchmark.addCase("BroadcastHashJoin w/o codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "false") + sqlContext.range(values).join(dim, (col("id") % 60000) === col("k")).count() + } + benchmark.addCase(s"BroadcastHashJoin w codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "true") + sqlContext.range(values).join(dim, (col("id") % 60000) === col("k")).count() + } + + /* + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + BroadcastHashJoin: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + BroadcastHashJoin w/o codegen 3053.41 3.43 1.00 X + BroadcastHashJoin w codegen 1028.40 10.20 2.97 X + */ + benchmark.run() + } + def testBytesToBytesMap(values: Int): Unit = { val benchmark = new Benchmark("BytesToBytesMap", values) @@ -201,6 +226,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { // testWholeStage(200 << 20) // testStatFunctions(20 << 20) // testAggregateWithKey(20 << 20) - // testBytesToBytesMap(1024 * 1024 * 50) + // testBytesToBytesMap(50 << 20) + // testBroadcastHashJoin(10 << 20) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index c2516509dfbb..9350205d791d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -20,8 +20,10 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.execution.aggregate.TungstenAggregate -import org.apache.spark.sql.functions.{avg, col, max} +import org.apache.spark.sql.execution.joins.BroadcastHashJoin +import org.apache.spark.sql.functions.{avg, broadcast, col, max} import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} 
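Since the diff above touches several places at once, here is a minimal, self-contained sketch of the new buffering contract in plain Scala (the `ToyBufferedRowIterator` name and its Int rows are invented for illustration, not the real generated class): `processNext()` may queue several output rows per input row, and the produce loop hands control back via `shouldStop()` as soon as output is buffered.

```scala
import java.util.{LinkedList => JLinkedList}

// Toy analogue of BufferedRowIterator after this change: output rows are queued,
// and shouldStop() lets the inner loop return once rows are available to consume.
class ToyBufferedRowIterator(input: Iterator[Int]) {
  private val currentRows = new JLinkedList[Int]()

  private def shouldStop(): Boolean = !currentRows.isEmpty

  // One input element may expand into several outputs, as in the hash-join match loop.
  private def processNext(): Unit = {
    while (input.hasNext && !shouldStop()) {
      val i = input.next()
      (0 to i % 2).foreach(_ => currentRows.add(i))
    }
  }

  def hasNext: Boolean = {
    if (currentRows.isEmpty) processNext()
    !currentRows.isEmpty
  }

  def next(): Int = currentRows.remove()
}
```

This is why the generated code above drops the old `currentRow = ...; return;` pattern in favour of `currentRows.add(...)` plus `if (shouldStop()) return;` checks inside the loops.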
class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { @@ -56,4 +58,15 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined) assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1))) } + + test("BroadcastHashJoin should be included in WholeStageCodegen") { + val rdd = sqlContext.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2"))) + val schema = new StructType().add("k", IntegerType).add("v", StringType) + val smallDF = sqlContext.createDataFrame(rdd, schema) + val df = sqlContext.range(10).join(broadcast(smallDF), col("k") === col("id")) + assert(df.queryExecution.executedPlan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[BroadcastHashJoin]).isDefined) + assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2"))) + } } From 9dd2741ebe5f9b5fa0a3b0e9c594d0e94b6226f9 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 3 Feb 2016 12:31:30 -0800 Subject: [PATCH 115/131] [SPARK-13157] [SQL] Support any kind of input for SQL commands. The ```SparkSqlLexer``` currently swallows characters which have not been defined in the grammar. This causes problems with SQL commands, such as: ```add jar file:///tmp/ab/TestUDTF.jar```. In this example the `````` is swallowed. This PR adds an extra Lexer rule to handle such input, and makes a tiny modification to the ```ASTNode```. cc davies liancheng Author: Herman van Hovell Closes #11052 from hvanhovell/SPARK-13157. --- .../spark/sql/catalyst/parser/SparkSqlLexer.g | 4 ++ .../spark/sql/catalyst/parser/ASTNode.scala | 4 +- .../sql/catalyst/parser/ASTNodeSuite.scala | 38 +++++++++++++++++++ .../HiveThriftServer2Suites.scala | 6 +-- 4 files changed, 46 insertions(+), 6 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g index e930caa291d4..1d07a27353dc 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g @@ -483,3 +483,7 @@ COMMENT { $channel=HIDDEN; } ; +/* Prevent that the lexer swallows unknown characters. */ +ANY + :. + ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala index ec9812414e19..28f7b10ed6a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala @@ -58,12 +58,12 @@ case class ASTNode( override val origin: Origin = Origin(Some(line), Some(positionInLine)) /** Source text. */ - lazy val source: String = stream.toString(startIndex, stopIndex) + lazy val source: String = stream.toOriginalString(startIndex, stopIndex) /** Get the source text that remains after this token. 
*/ lazy val remainder: String = { stream.fill() - stream.toString(stopIndex + 1, stream.size() - 1).trim() + stream.toOriginalString(stopIndex + 1, stream.size() - 1).trim() } def text: String = token.getText diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala new file mode 100644 index 000000000000..8b05f9e33d69 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.SparkFunSuite + +class ASTNodeSuite extends SparkFunSuite { + test("SPARK-13157 - remainder must return all input chars") { + val inputs = Seq( + ("add jar", "file:///tmp/ab/TestUDTF.jar"), + ("add jar", "file:///tmp/a@b/TestUDTF.jar"), + ("add jar", "c:\\windows32\\TestUDTF.jar"), + ("add jar", "some \nbad\t\tfile\r\n.\njar"), + ("ADD JAR", "@*#&@(!#@$^*!@^@#(*!@#"), + ("SET", "foo=bar"), + ("SET", "foo*)(@#^*@&!#^=bar") + ) + inputs.foreach { + case (command, arguments) => + val node = ParseDriver.parsePlan(s"$command $arguments", null) + assert(node.remainder === arguments) + } + } +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 9860e40fe854..ba3b26e1b7d4 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -488,8 +488,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } } - // TODO: enable this - ignore("SPARK-11595 ADD JAR with input path having URL scheme") { + test("SPARK-11595 ADD JAR with input path having URL scheme") { withJdbcStatement { statement => val jarPath = "../hive/src/test/resources/TestUDTF.jar" val jarURL = s"file://${System.getProperty("user.dir")}/$jarPath" @@ -547,8 +546,7 @@ class SingleSessionSuite extends HiveThriftJdbcTest { override protected def extraConf: Seq[String] = "--conf spark.sql.hive.thriftServer.singleSession=true" :: Nil - // TODO: enable this - ignore("test single session") { + test("test single session") { withMultipleConnectionJdbcStatement( { statement => val jarPath = "../hive/src/test/resources/TestUDTF.jar" From 3221eddb8f9728f65c579969a3a88baeeb7577a9 Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Wed, 3 Feb 2016 15:53:10 -0800 Subject: [PATCH 116/131] [SPARK-3611][WEB UI] Show number of cores for each executor in application web UI 
Added a Cores column in the Executors UI Author: Alex Bozarth Closes #11039 from ajbozarth/spark3611. --- .../main/scala/org/apache/spark/status/api/v1/api.scala | 1 + .../scala/org/apache/spark/ui/exec/ExecutorsPage.scala | 7 +++++++ .../main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala | 5 +++-- .../executor_list_json_expectation.json | 1 + 4 files changed, 12 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 2b0079f5fd62..d116e68c17f1 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -57,6 +57,7 @@ class ExecutorSummary private[spark]( val rddBlocks: Int, val memoryUsed: Long, val diskUsed: Long, + val totalCores: Int, val maxTasks: Int, val activeTasks: Int, val failedTasks: Int, diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index e36b96b3e697..e1f754999912 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -75,6 +75,7 @@ private[ui] class ExecutorsPage( RDD Blocks Storage Memory Disk Used + Cores Active Tasks Failed Tasks Complete Tasks @@ -131,6 +132,7 @@ private[ui] class ExecutorsPage( {Utils.bytesToString(diskUsed)} + {info.totalCores} {taskData(info.maxTasks, info.activeTasks, info.failedTasks, info.completedTasks, info.totalTasks, info.totalDuration, info.totalGCTime)} @@ -174,6 +176,7 @@ private[ui] class ExecutorsPage( val maximumMemory = execInfo.map(_.maxMemory).sum val memoryUsed = execInfo.map(_.memoryUsed).sum val diskUsed = execInfo.map(_.diskUsed).sum + val totalCores = execInfo.map(_.totalCores).sum val totalInputBytes = execInfo.map(_.totalInputBytes).sum val totalShuffleRead = execInfo.map(_.totalShuffleRead).sum val totalShuffleWrite = execInfo.map(_.totalShuffleWrite).sum @@ -188,6 +191,7 @@ private[ui] class ExecutorsPage( {Utils.bytesToString(diskUsed)} + {totalCores} {taskData(execInfo.map(_.maxTasks).sum, execInfo.map(_.activeTasks).sum, execInfo.map(_.failedTasks).sum, @@ -211,6 +215,7 @@ private[ui] class ExecutorsPage( RDD Blocks Storage Memory Disk Used + Cores Active Tasks Failed Tasks Complete Tasks @@ -305,6 +310,7 @@ private[spark] object ExecutorsPage { val memUsed = status.memUsed val maxMem = status.maxMem val diskUsed = status.diskUsed + val totalCores = listener.executorToTotalCores.getOrElse(execId, 0) val maxTasks = listener.executorToTasksMax.getOrElse(execId, 0) val activeTasks = listener.executorToTasksActive.getOrElse(execId, 0) val failedTasks = listener.executorToTasksFailed.getOrElse(execId, 0) @@ -323,6 +329,7 @@ private[spark] object ExecutorsPage { rddBlocks, memUsed, diskUsed, + totalCores, maxTasks, activeTasks, failedTasks, diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index a9e926b15878..dcfebe92ed80 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -45,6 +45,7 @@ private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "exec @DeveloperApi class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: SparkConf) extends SparkListener { + val executorToTotalCores = HashMap[String, Int]() val 
executorToTasksMax = HashMap[String, Int]() val executorToTasksActive = HashMap[String, Int]() val executorToTasksComplete = HashMap[String, Int]() @@ -65,8 +66,8 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: Spar override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = synchronized { val eid = executorAdded.executorId executorToLogUrls(eid) = executorAdded.executorInfo.logUrlMap - executorToTasksMax(eid) = - executorAdded.executorInfo.totalCores / conf.getInt("spark.task.cpus", 1) + executorToTotalCores(eid) = executorAdded.executorInfo.totalCores + executorToTasksMax(eid) = executorToTotalCores(eid) / conf.getInt("spark.task.cpus", 1) executorIdToData(eid) = ExecutorUIData(executorAdded.time) } diff --git a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json index 94f8aeac55b5..9d5d224e5517 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json @@ -4,6 +4,7 @@ "rddBlocks" : 8, "memoryUsed" : 28000128, "diskUsed" : 0, + "totalCores" : 0, "maxTasks" : 0, "activeTasks" : 0, "failedTasks" : 1, From 915a75398ecbccdbf9a1e07333104c857ae1ce5e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 3 Feb 2016 16:10:11 -0800 Subject: [PATCH 117/131] [SPARK-13166][SQL] Remove DataStreamReader/Writer They seem redundant and we can simply use DataFrameReader/Writer. The new usage looks like: ```scala val df = sqlContext.read.stream("...") val handle = df.write.stream("...") handle.stop() ``` Author: Reynold Xin Closes #11062 from rxin/SPARK-13166. --- .../org/apache/spark/sql/DataFrame.scala | 10 +- .../apache/spark/sql/DataFrameReader.scala | 29 +++- .../apache/spark/sql/DataFrameWriter.scala | 36 ++++- .../apache/spark/sql/DataStreamReader.scala | 127 ----------------- .../apache/spark/sql/DataStreamWriter.scala | 134 ------------------ .../org/apache/spark/sql/SQLContext.scala | 11 +- .../datasources/ResolvedDataSource.scala | 1 - .../sql/streaming/DataStreamReaderSuite.scala | 53 ++++--- 8 files changed, 86 insertions(+), 315 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/DataStreamReader.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/DataStreamWriter.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 6de17e5924d0..84203bbfef66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1682,7 +1682,7 @@ class DataFrame private[sql]( /** * :: Experimental :: - * Interface for saving the content of the [[DataFrame]] out into external storage. + * Interface for saving the content of the [[DataFrame]] out into external storage or streams. * * @group output * @since 1.4.0 @@ -1690,14 +1690,6 @@ class DataFrame private[sql]( @Experimental def write: DataFrameWriter = new DataFrameWriter(this) - /** - * :: Experimental :: - * Interface for starting a streaming query that will continually output results to the specified - * external sink as new data arrives. - */ - @Experimental - def streamTo: DataStreamWriter = new DataStreamWriter(this) - /** * Returns the content of the [[DataFrame]] as a RDD of JSON strings. 
* @group rdd diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 2e0c6c7df967..a58643a5ba15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -29,17 +29,17 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.{CatalystQl} import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} import org.apache.spark.sql.execution.datasources.json.JSONRelation import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation +import org.apache.spark.sql.execution.streaming.StreamingRelation import org.apache.spark.sql.types.StructType /** * :: Experimental :: * Interface used to load a [[DataFrame]] from external storage systems (e.g. file systems, - * key-value stores, etc). Use [[SQLContext.read]] to access this. + * key-value stores, etc) or data streams. Use [[SQLContext.read]] to access this. * * @since 1.4.0 */ @@ -136,6 +136,30 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { option("paths", paths.map(StringUtils.escapeString(_, '\\', ',')).mkString(",")).load() } + /** + * Loads input data stream in as a [[DataFrame]], for data streams that don't require a path + * (e.g. external key-value stores). + * + * @since 2.0.0 + */ + def stream(): DataFrame = { + val resolved = ResolvedDataSource.createSource( + sqlContext, + userSpecifiedSchema = userSpecifiedSchema, + providerName = source, + options = extraOptions.toMap) + DataFrame(sqlContext, StreamingRelation(resolved)) + } + + /** + * Loads input in as a [[DataFrame]], for data streams that read from some path. + * + * @since 2.0.0 + */ + def stream(path: String): DataFrame = { + option("path", path).stream() + } + /** * Construct a [[DataFrame]] representing the database table accessible via JDBC URL * url named table and connection properties. @@ -165,7 +189,6 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * @param connectionProperties JDBC database connection arguments, a list of arbitrary string * tag/value. Normally at least a "user" and "password" property * should be included. 
- * * @since 1.4.0 */ def jdbc( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 12eb2393634a..80447fefe1f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -22,17 +22,18 @@ import java.util.Properties import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.{CatalystQl, TableIdentifier} +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project} import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, ResolvedDataSource} import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils +import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.sources.HadoopFsRelation /** * :: Experimental :: * Interface used to write a [[DataFrame]] to external storage systems (e.g. file systems, - * key-value stores, etc). Use [[DataFrame.write]] to access this. + * key-value stores, etc) or data streams. Use [[DataFrame.write]] to access this. * * @since 1.4.0 */ @@ -183,6 +184,34 @@ final class DataFrameWriter private[sql](df: DataFrame) { df) } + /** + * Starts the execution of the streaming query, which will continually output results to the given + * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with + * the stream. + * + * @since 2.0.0 + */ + def stream(path: String): ContinuousQuery = { + option("path", path).stream() + } + + /** + * Starts the execution of the streaming query, which will continually output results to the given + * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with + * the stream. + * + * @since 2.0.0 + */ + def stream(): ContinuousQuery = { + val sink = ResolvedDataSource.createSink( + df.sqlContext, + source, + extraOptions.toMap, + normalizedParCols.getOrElse(Nil)) + + new StreamExecution(df.sqlContext, df.logicalPlan, sink) + } + /** * Inserts the content of the [[DataFrame]] to the specified table. It requires that * the schema of the [[DataFrame]] is the same as the schema of the table. @@ -255,7 +284,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { /** * The given column name may not be equal to any of the existing column names if we were in - * case-insensitive context. Normalize the given column name to the real one so that we don't + * case-insensitive context. Normalize the given column name to the real one so that we don't * need to care about case sensitivity afterwards. */ private def normalize(columnName: String, columnType: String): String = { @@ -339,7 +368,6 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @param connectionProperties JDBC database connection arguments, a list of arbitrary string * tag/value. Normally at least a "user" and "password" property * should be included. 
- * * @since 1.4.0 */ def jdbc(url: String, table: String, connectionProperties: Properties): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataStreamReader.scala deleted file mode 100644 index 2febc93fa49d..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataStreamReader.scala +++ /dev/null @@ -1,127 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql - -import scala.collection.JavaConverters._ - -import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.execution.datasources.ResolvedDataSource -import org.apache.spark.sql.execution.streaming.StreamingRelation -import org.apache.spark.sql.types.StructType - -/** - * :: Experimental :: - * An interface to reading streaming data. Use `sqlContext.streamFrom` to access these methods. - * - * {{{ - * val df = sqlContext.streamFrom - * .format("...") - * .open() - * }}} - */ -@Experimental -class DataStreamReader private[sql](sqlContext: SQLContext) extends Logging { - - /** - * Specifies the input data source format. - * - * @since 2.0.0 - */ - def format(source: String): DataStreamReader = { - this.source = source - this - } - - /** - * Specifies the input schema. Some data streams (e.g. JSON) can infer the input schema - * automatically from data. By specifying the schema here, the underlying data stream can - * skip the schema inference step, and thus speed up data reading. - * - * @since 2.0.0 - */ - def schema(schema: StructType): DataStreamReader = { - this.userSpecifiedSchema = Option(schema) - this - } - - /** - * Adds an input option for the underlying data stream. - * - * @since 2.0.0 - */ - def option(key: String, value: String): DataStreamReader = { - this.extraOptions += (key -> value) - this - } - - /** - * (Scala-specific) Adds input options for the underlying data stream. - * - * @since 2.0.0 - */ - def options(options: scala.collection.Map[String, String]): DataStreamReader = { - this.extraOptions ++= options - this - } - - /** - * Adds input options for the underlying data stream. - * - * @since 2.0.0 - */ - def options(options: java.util.Map[String, String]): DataStreamReader = { - this.options(options.asScala) - this - } - - /** - * Loads streaming input in as a [[DataFrame]], for data streams that don't require a path (e.g. - * external key-value stores). 
- * - * @since 2.0.0 - */ - def open(): DataFrame = { - val resolved = ResolvedDataSource.createSource( - sqlContext, - userSpecifiedSchema = userSpecifiedSchema, - providerName = source, - options = extraOptions.toMap) - DataFrame(sqlContext, StreamingRelation(resolved)) - } - - /** - * Loads input in as a [[DataFrame]], for data streams that read from some path. - * - * @since 2.0.0 - */ - def open(path: String): DataFrame = { - option("path", path).open() - } - - /////////////////////////////////////////////////////////////////////////////////////// - // Builder pattern config options - /////////////////////////////////////////////////////////////////////////////////////// - - private var source: String = sqlContext.conf.defaultDataSourceName - - private var userSpecifiedSchema: Option[StructType] = None - - private var extraOptions = new scala.collection.mutable.HashMap[String, String] - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataStreamWriter.scala deleted file mode 100644 index b325d48fcbbb..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataStreamWriter.scala +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import scala.collection.JavaConverters._ - -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.execution.datasources.ResolvedDataSource -import org.apache.spark.sql.execution.streaming.StreamExecution - -/** - * :: Experimental :: - * Interface used to start a streaming query query execution. - * - * @since 2.0.0 - */ -@Experimental -final class DataStreamWriter private[sql](df: DataFrame) { - - /** - * Specifies the underlying output data source. Built-in options include "parquet", "json", etc. - * - * @since 2.0.0 - */ - def format(source: String): DataStreamWriter = { - this.source = source - this - } - - /** - * Adds an output option for the underlying data source. - * - * @since 2.0.0 - */ - def option(key: String, value: String): DataStreamWriter = { - this.extraOptions += (key -> value) - this - } - - /** - * (Scala-specific) Adds output options for the underlying data source. - * - * @since 2.0.0 - */ - def options(options: scala.collection.Map[String, String]): DataStreamWriter = { - this.extraOptions ++= options - this - } - - /** - * Adds output options for the underlying data source. - * - * @since 2.0.0 - */ - def options(options: java.util.Map[String, String]): DataStreamWriter = { - this.options(options.asScala) - this - } - - /** - * Partitions the output by the given columns on the file system. 
If specified, the output is - * laid out on the file system similar to Hive's partitioning scheme.\ - * @since 2.0.0 - */ - @scala.annotation.varargs - def partitionBy(colNames: String*): DataStreamWriter = { - this.partitioningColumns = colNames - this - } - - /** - * Starts the execution of the streaming query, which will continually output results to the given - * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with - * the stream. - * @since 2.0.0 - */ - def start(path: String): ContinuousQuery = { - this.extraOptions += ("path" -> path) - start() - } - - /** - * Starts the execution of the streaming query, which will continually output results to the given - * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with - * the stream. - * - * @since 2.0.0 - */ - def start(): ContinuousQuery = { - val sink = ResolvedDataSource.createSink( - df.sqlContext, - source, - extraOptions.toMap, - normalizedParCols) - - new StreamExecution(df.sqlContext, df.logicalPlan, sink) - } - - private def normalizedParCols: Seq[String] = { - partitioningColumns.map { col => - df.logicalPlan.output - .map(_.name) - .find(df.sqlContext.analyzer.resolver(_, col)) - .getOrElse(throw new AnalysisException(s"Partition column $col not found in existing " + - s"columns (${df.logicalPlan.output.map(_.name).mkString(", ")})")) - } - } - - /////////////////////////////////////////////////////////////////////////////////////// - // Builder pattern config options - /////////////////////////////////////////////////////////////////////////////////////// - - private var source: String = df.sqlContext.conf.defaultDataSourceName - - private var extraOptions = new scala.collection.mutable.HashMap[String, String] - - private var partitioningColumns: Seq[String] = Nil - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 13700be06828..1661fdbec532 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -579,10 +579,9 @@ class SQLContext private[sql]( DataFrame(self, LocalRelation(attrSeq, rows.toSeq)) } - /** * :: Experimental :: - * Returns a [[DataFrameReader]] that can be used to read data in as a [[DataFrame]]. + * Returns a [[DataFrameReader]] that can be used to read data and streams in as a [[DataFrame]]. * {{{ * sqlContext.read.parquet("/path/to/file.parquet") * sqlContext.read.schema(schema).json("/path/to/file.json") @@ -594,14 +593,6 @@ class SQLContext private[sql]( @Experimental def read: DataFrameReader = new DataFrameReader(this) - - /** - * :: Experimental :: - * Returns a [[DataStreamReader]] than can be used to access data continuously as it arrives. - */ - @Experimental - def streamFrom: DataStreamReader = new DataStreamReader(this) - /** * :: Experimental :: * Creates an external table from the given path and returns the corresponding DataFrame. 
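Taking the reader and writer changes in this patch together, a rough usage sketch of the consolidated API (the format name, option and paths below are placeholders; the test suite further down exercises the same calls with the dummy `org.apache.spark.sql.streaming.test` provider, and `sqlContext` is assumed to be in scope):

```scala
// Assumes a streaming-capable source/sink provider named "some.streaming.source" (placeholder).
val df = sqlContext.read
  .format("some.streaming.source")
  .option("opt1", "1")
  .stream("/path/to/input")        // streaming DataFrame; replaces streamFrom ... open()

val query = df.write
  .format("some.streaming.source")
  .partitionBy("a")
  .stream("/path/to/output")       // returns a ContinuousQuery; replaces streamTo ... start()

query.stop()
```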
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala index e3065ac5f87d..7702f535ad2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -122,7 +122,6 @@ object ResolvedDataSource extends Logging { provider.createSink(sqlContext, options, partitionColumns) } - /** Create a [[ResolvedDataSource]] for reading data in. */ def apply( sqlContext: SQLContext, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala index 1dab6ebf1bee..b36b41cac9b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala @@ -60,22 +60,22 @@ class DataStreamReaderWriterSuite extends StreamTest with SharedSQLContext { import testImplicits._ test("resolve default source") { - sqlContext.streamFrom + sqlContext.read .format("org.apache.spark.sql.streaming.test") - .open() - .streamTo + .stream() + .write .format("org.apache.spark.sql.streaming.test") - .start() + .stream() .stop() } test("resolve full class") { - sqlContext.streamFrom + sqlContext.read .format("org.apache.spark.sql.streaming.test.DefaultSource") - .open() - .streamTo + .stream() + .write .format("org.apache.spark.sql.streaming.test") - .start() + .stream() .stop() } @@ -83,12 +83,12 @@ class DataStreamReaderWriterSuite extends StreamTest with SharedSQLContext { val map = new java.util.HashMap[String, String] map.put("opt3", "3") - val df = sqlContext.streamFrom + val df = sqlContext.read .format("org.apache.spark.sql.streaming.test") .option("opt1", "1") .options(Map("opt2" -> "2")) .options(map) - .open() + .stream() assert(LastOptions.parameters("opt1") == "1") assert(LastOptions.parameters("opt2") == "2") @@ -96,12 +96,12 @@ class DataStreamReaderWriterSuite extends StreamTest with SharedSQLContext { LastOptions.parameters = null - df.streamTo + df.write .format("org.apache.spark.sql.streaming.test") .option("opt1", "1") .options(Map("opt2" -> "2")) .options(map) - .start() + .stream() .stop() assert(LastOptions.parameters("opt1") == "1") @@ -110,54 +110,53 @@ class DataStreamReaderWriterSuite extends StreamTest with SharedSQLContext { } test("partitioning") { - val df = sqlContext.streamFrom + val df = sqlContext.read .format("org.apache.spark.sql.streaming.test") - .open() + .stream() - df.streamTo + df.write .format("org.apache.spark.sql.streaming.test") - .start() + .stream() .stop() assert(LastOptions.partitionColumns == Nil) - df.streamTo + df.write .format("org.apache.spark.sql.streaming.test") .partitionBy("a") - .start() + .stream() .stop() assert(LastOptions.partitionColumns == Seq("a")) - withSQLConf("spark.sql.caseSensitive" -> "false") { - df.streamTo + df.write .format("org.apache.spark.sql.streaming.test") .partitionBy("A") - .start() + .stream() .stop() assert(LastOptions.partitionColumns == Seq("a")) } intercept[AnalysisException] { - df.streamTo + df.write .format("org.apache.spark.sql.streaming.test") .partitionBy("b") - .start() + .stream() .stop() } } test("stream paths") { - val df = sqlContext.streamFrom + val df = sqlContext.read .format("org.apache.spark.sql.streaming.test") - 
.open("/test") + .stream("/test") assert(LastOptions.parameters("path") == "/test") LastOptions.parameters = null - df.streamTo + df.write .format("org.apache.spark.sql.streaming.test") - .start("/test") + .stream("/test") .stop() assert(LastOptions.parameters("path") == "/test") From de0914522fc5b2658959f9e2272b4e3162b14978 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 3 Feb 2016 17:07:27 -0800 Subject: [PATCH 118/131] [SPARK-13131] [SQL] Use best and average time in benchmark Best time is stabler than average time, also added a column for nano seconds per row (which could be used to estimate contributions of each components in a query). Having best time and average time together for more information (we can see kind of variance). rate, time per row and relative are all calculated using best time. The result looks like this: ``` Intel(R) Core(TM) i7-4558U CPU 2.80GHz rang/filter/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- rang/filter/sum codegen=false 14332 / 16646 36.0 27.8 1.0X rang/filter/sum codegen=true 845 / 940 620.0 1.6 17.0X ``` Author: Davies Liu Closes #11018 from davies/gen_bench. --- .../org/apache/spark/util/Benchmark.scala | 38 +++-- .../BenchmarkWholeStageCodegen.scala | 154 ++++++++---------- 2 files changed, 89 insertions(+), 103 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Benchmark.scala b/core/src/main/scala/org/apache/spark/util/Benchmark.scala index d484cec7ae38..d1699f5c2865 100644 --- a/core/src/main/scala/org/apache/spark/util/Benchmark.scala +++ b/core/src/main/scala/org/apache/spark/util/Benchmark.scala @@ -18,6 +18,7 @@ package org.apache.spark.util import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import org.apache.commons.lang3.SystemUtils @@ -59,17 +60,21 @@ private[spark] class Benchmark( } println - val firstRate = results.head.avgRate + val firstBest = results.head.bestMs + val firstAvg = results.head.avgMs // The results are going to be processor specific so it is useful to include that. println(Benchmark.getProcessorName()) - printf("%-30s %16s %16s %14s\n", name + ":", "Avg Time(ms)", "Avg Rate(M/s)", "Relative Rate") - println("-------------------------------------------------------------------------------") - results.zip(benchmarks).foreach { r => - printf("%-30s %16s %16s %14s\n", - r._2.name, - "%10.2f" format r._1.avgMs, - "%10.2f" format r._1.avgRate, - "%6.2f X" format (r._1.avgRate / firstRate)) + printf("%-35s %16s %12s %13s %10s\n", name + ":", "Best/Avg Time(ms)", "Rate(M/s)", + "Per Row(ns)", "Relative") + println("-----------------------------------------------------------------------------------" + + "--------") + results.zip(benchmarks).foreach { case (result, benchmark) => + printf("%-35s %16s %12s %13s %10s\n", + benchmark.name, + "%5.0f / %4.0f" format (result.bestMs, result.avgMs), + "%10.1f" format result.bestRate, + "%6.1f" format (1000 / result.bestRate), + "%3.1fX" format (firstBest / result.bestMs)) } println // scalastyle:on @@ -78,7 +83,7 @@ private[spark] class Benchmark( private[spark] object Benchmark { case class Case(name: String, fn: Int => Unit) - case class Result(avgMs: Double, avgRate: Double) + case class Result(avgMs: Double, bestRate: Double, bestMs: Double) /** * This should return a user helpful processor information. Getting at this depends on the OS. @@ -99,22 +104,27 @@ private[spark] object Benchmark { * the rate of the function. 
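As a small restatement of the printf logic above (not new behaviour), the two derived columns in the new report come out of `Result(avgMs, bestRate, bestMs)` as follows, with rates in millions of rows per second and times in milliseconds:

```scala
// Per Row(ns): nanoseconds per row at the best observed rate (M rows/s).
def perRowNs(bestRateMRowsPerSec: Double): Double = 1000.0 / bestRateMRowsPerSec

// Relative: how many times faster this case ran than the first (baseline) case.
def relativeSpeedup(firstBestMs: Double, bestMs: Double): Double = firstBestMs / bestMs

// e.g. the rang/filter/sum table above: perRowNs(620.0) ~ 1.6 ns,
// relativeSpeedup(14332, 845) ~ 17.0X
```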
*/ def measure(num: Long, iters: Int, outputPerIteration: Boolean)(f: Int => Unit): Result = { - var totalTime = 0L + val runTimes = ArrayBuffer[Long]() for (i <- 0 until iters + 1) { val start = System.nanoTime() f(i) val end = System.nanoTime() - if (i != 0) totalTime += end - start + val runTime = end - start + if (i > 0) { + runTimes += runTime + } if (outputPerIteration) { // scalastyle:off - println(s"Iteration $i took ${(end - start) / 1000} microseconds") + println(s"Iteration $i took ${runTime / 1000} microseconds") // scalastyle:on } } - Result(totalTime.toDouble / 1000000 / iters, num * iters / (totalTime.toDouble / 1000)) + val best = runTimes.min + val avg = runTimes.sum / iters + Result(avg / 1000000, num / (best / 1000), best / 1000000) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 15ba77353109..33d4976403d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -34,54 +34,47 @@ import org.apache.spark.util.Benchmark */ class BenchmarkWholeStageCodegen extends SparkFunSuite { lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark") + .set("spark.sql.shuffle.partitions", "1") lazy val sc = SparkContext.getOrCreate(conf) lazy val sqlContext = SQLContext.getOrCreate(sc) - def testWholeStage(values: Int): Unit = { - val benchmark = new Benchmark("rang/filter/aggregate", values) + def runBenchmark(name: String, values: Int)(f: => Unit): Unit = { + val benchmark = new Benchmark(name, values) - benchmark.addCase("Without codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "false") - sqlContext.range(values).filter("(id & 1) = 1").count() - } - - benchmark.addCase("With codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "true") - sqlContext.range(values).filter("(id & 1) = 1").count() + Seq(false, true).foreach { enabled => + benchmark.addCase(s"$name codegen=$enabled") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", enabled.toString) + f + } } - /* - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - rang/filter/aggregate: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------------- - Without codegen 7775.53 26.97 1.00 X - With codegen 342.15 612.94 22.73 X - */ benchmark.run() } - def testStatFunctions(values: Int): Unit = { - - val benchmark = new Benchmark("stat functions", values) - - benchmark.addCase("stddev w/o codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "false") - sqlContext.range(values).groupBy().agg("id" -> "stddev").collect() + // These benchmark are skipped in normal build + ignore("range/filter/sum") { + val N = 500 << 20 + runBenchmark("rang/filter/sum", N) { + sqlContext.range(N).filter("(id & 1) = 1").groupBy().sum().collect() } + /* + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + rang/filter/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + rang/filter/sum codegen=false 14332 / 16646 36.0 27.8 1.0X + rang/filter/sum codegen=true 845 / 940 620.0 1.6 17.0X + */ + } - benchmark.addCase("stddev w codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "true") - sqlContext.range(values).groupBy().agg("id" 
-> "stddev").collect() - } + ignore("stat functions") { + val N = 100 << 20 - benchmark.addCase("kurtosis w/o codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "false") - sqlContext.range(values).groupBy().agg("id" -> "kurtosis").collect() + runBenchmark("stddev", N) { + sqlContext.range(N).groupBy().agg("id" -> "stddev").collect() } - benchmark.addCase("kurtosis w codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "true") - sqlContext.range(values).groupBy().agg("id" -> "kurtosis").collect() + runBenchmark("kurtosis", N) { + sqlContext.range(N).groupBy().agg("id" -> "kurtosis").collect() } @@ -99,64 +92,56 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Using DeclarativeAggregate: Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - stddev: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------------- - stddev w/o codegen 989.22 21.20 1.00 X - stddev w codegen 352.35 59.52 2.81 X - kurtosis w/o codegen 3636.91 5.77 0.27 X - kurtosis w codegen 369.25 56.79 2.68 X + stddev: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + stddev codegen=false 5630 / 5776 18.0 55.6 1.0X + stddev codegen=true 1259 / 1314 83.0 12.0 4.5X + + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + kurtosis: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + kurtosis codegen=false 14847 / 15084 7.0 142.9 1.0X + kurtosis codegen=true 1652 / 2124 63.0 15.9 9.0X */ - benchmark.run() } - def testAggregateWithKey(values: Int): Unit = { - val benchmark = new Benchmark("Aggregate with keys", values) + ignore("aggregate with keys") { + val N = 20 << 20 - benchmark.addCase("Aggregate w/o codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "false") - sqlContext.range(values).selectExpr("(id & 65535) as k").groupBy("k").sum().collect() - } - benchmark.addCase(s"Aggregate w codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "true") - sqlContext.range(values).selectExpr("(id & 65535) as k").groupBy("k").sum().collect() + runBenchmark("Aggregate w keys", N) { + sqlContext.range(N).selectExpr("(id & 65535) as k").groupBy("k").sum().collect() } /* Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Aggregate with keys: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------------- - Aggregate w/o codegen 4254.38 4.93 1.00 X - Aggregate w codegen 2661.45 7.88 1.60 X + Aggregate w keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Aggregate w keys codegen=false 2402 / 2551 8.0 125.0 1.0X + Aggregate w keys codegen=true 1620 / 1670 12.0 83.3 1.5X */ - benchmark.run() } - def testBroadcastHashJoin(values: Int): Unit = { - val benchmark = new Benchmark("BroadcastHashJoin", values) - + ignore("broadcast hash join") { + val N = 20 << 20 val dim = broadcast(sqlContext.range(1 << 16).selectExpr("id as k", "cast(id as string) as v")) - benchmark.addCase("BroadcastHashJoin w/o codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "false") - sqlContext.range(values).join(dim, (col("id") % 60000) === col("k")).count() - } - benchmark.addCase(s"BroadcastHashJoin w codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "true") - 
sqlContext.range(values).join(dim, (col("id") % 60000) === col("k")).count() + runBenchmark("BroadcastHashJoin", N) { + sqlContext.range(N).join(dim, (col("id") % 60000) === col("k")).count() } /* - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - BroadcastHashJoin: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------------- - BroadcastHashJoin w/o codegen 3053.41 3.43 1.00 X - BroadcastHashJoin w codegen 1028.40 10.20 2.97 X + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + BroadcastHashJoin: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + BroadcastHashJoin codegen=false 4405 / 6147 4.0 250.0 1.0X + BroadcastHashJoin codegen=true 1857 / 1878 11.0 90.9 2.4X */ - benchmark.run() } - def testBytesToBytesMap(values: Int): Unit = { - val benchmark = new Benchmark("BytesToBytesMap", values) + ignore("hash and BytesToBytesMap") { + val N = 50 << 20 + + val benchmark = new Benchmark("BytesToBytesMap", N) benchmark.addCase("hash") { iter => var i = 0 @@ -167,7 +152,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { val value = new UnsafeRow(2) value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) var s = 0 - while (i < values) { + while (i < N) { key.setInt(0, i % 1000) val h = Murmur3_x86_32.hashUnsafeWords( key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 0) @@ -194,7 +179,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { val value = new UnsafeRow(2) value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) var i = 0 - while (i < values) { + while (i < N) { key.setInt(0, i % 65536) val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes) if (loc.isDefined) { @@ -212,21 +197,12 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { /** Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Aggregate with keys: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------------- - hash 662.06 79.19 1.00 X - BytesToBytesMap (off Heap) 2209.42 23.73 0.30 X - BytesToBytesMap (on Heap) 2957.68 17.73 0.22 X + BytesToBytesMap: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + hash 628 / 661 83.0 12.0 1.0X + BytesToBytesMap (off Heap) 3292 / 3408 15.0 66.7 0.2X + BytesToBytesMap (on Heap) 3349 / 4267 15.0 66.7 0.2X */ benchmark.run() } - - // These benchmark are skipped in normal build - ignore("benchmark") { - // testWholeStage(200 << 20) - // testStatFunctions(20 << 20) - // testAggregateWithKey(20 << 20) - // testBytesToBytesMap(50 << 20) - // testBroadcastHashJoin(10 << 20) - } } From a8e2ba776b20c8054918af646d8228bba1b87c9b Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 3 Feb 2016 17:43:14 -0800 Subject: [PATCH 119/131] [SPARK-13152][CORE] Fix task metrics deprecation warning Make an internal non-deprecated version of incBytesRead and incRecordsRead so we don't have unecessary deprecation warnings in our build. Right now incBytesRead and incRecordsRead are marked as deprecated and for internal use only. We should make private[spark] versions which are not deprecated and switch to those internally so as to not clutter up the warning messages when building. cc andrewor14 who did the initial deprecation Author: Holden Karau Closes #11056 from holdenk/SPARK-13152-fix-task-metrics-deprecation-warnings. 
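The change boils down to a small shim pattern; a minimal sketch with a stand-in class rather than the real `InputMetrics` (class and field names below are invented for illustration):

```scala
package org.apache.spark  // needed so that private[spark] compiles; placement is illustrative

class MetricsShimExample {
  private var bytesRead = 0L

  // Public API entry point stays deprecated, as before.
  @deprecated("incrementing input metrics is for internal use only", "2.0.0")
  def incBytesRead(v: Long): Unit = bytesRead += v

  // Non-deprecated twin used by internal call sites, keeping the build free of warnings.
  private[spark] def incBytesReadInternal(v: Long): Unit = bytesRead += v
}
```

Internal callers (CacheManager, HadoopRDD, NewHadoopRDD, JsonProtocol, SqlNewHadoopRDD) then switch to the `*Internal` variants, which is exactly what the diff below does.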
--- core/src/main/scala/org/apache/spark/CacheManager.scala | 4 ++-- .../main/scala/org/apache/spark/executor/InputMetrics.scala | 5 +++++ core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala | 4 ++-- core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala | 4 ++-- core/src/main/scala/org/apache/spark/util/JsonProtocol.scala | 4 ++-- .../spark/sql/execution/datasources/SqlNewHadoopRDD.scala | 4 ++-- 6 files changed, 15 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index fa8e2b953835..923ff411ce25 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -44,12 +44,12 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { case Some(blockResult) => // Partition is already materialized, so just return its values val existingMetrics = context.taskMetrics().registerInputMetrics(blockResult.readMethod) - existingMetrics.incBytesRead(blockResult.bytes) + existingMetrics.incBytesReadInternal(blockResult.bytes) val iter = blockResult.data.asInstanceOf[Iterator[T]] new InterruptibleIterator[T](context, iter) { override def next(): T = { - existingMetrics.incRecordsRead(1) + existingMetrics.incRecordsReadInternal(1) delegate.next() } } diff --git a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala index ed9e157ce758..6d30d3c76a9f 100644 --- a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala @@ -81,10 +81,15 @@ class InputMetrics private ( */ def readMethod: DataReadMethod.Value = DataReadMethod.withName(_readMethod.localValue) + // Once incBytesRead & intRecordsRead is ready to be removed from the public API + // we can remove the internal versions and make the previous public API private. + // This has been done to suppress warnings when building. @deprecated("incrementing input metrics is for internal use only", "2.0.0") def incBytesRead(v: Long): Unit = _bytesRead.add(v) + private[spark] def incBytesReadInternal(v: Long): Unit = _bytesRead.add(v) @deprecated("incrementing input metrics is for internal use only", "2.0.0") def incRecordsRead(v: Long): Unit = _recordsRead.add(v) + private[spark] def incRecordsReadInternal(v: Long): Unit = _recordsRead.add(v) private[spark] def setBytesRead(v: Long): Unit = _bytesRead.setValue(v) private[spark] def setReadMethod(v: DataReadMethod.Value): Unit = _readMethod.setValue(v.toString) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index e2ebd7f00d0d..805cd9fe1f63 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -260,7 +260,7 @@ class HadoopRDD[K, V]( finished = true } if (!finished) { - inputMetrics.incRecordsRead(1) + inputMetrics.incRecordsReadInternal(1) } if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { updateBytesRead() @@ -292,7 +292,7 @@ class HadoopRDD[K, V]( // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. 
try { - inputMetrics.incBytesRead(split.inputSplit.value.getLength) + inputMetrics.incBytesReadInternal(split.inputSplit.value.getLength) } catch { case e: java.io.IOException => logWarning("Unable to get input size to set InputMetrics for task", e) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index e71d3405c0ea..f23da39eb90d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -188,7 +188,7 @@ class NewHadoopRDD[K, V]( } havePair = false if (!finished) { - inputMetrics.incRecordsRead(1) + inputMetrics.incRecordsReadInternal(1) } if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { updateBytesRead() @@ -219,7 +219,7 @@ class NewHadoopRDD[K, V]( // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { - inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) + inputMetrics.incBytesReadInternal(split.serializableHadoopSplit.value.getLength) } catch { case e: java.io.IOException => logWarning("Unable to get input size to set InputMetrics for task", e) diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index a2487eeb0483..38e6478d80f0 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -811,8 +811,8 @@ private[spark] object JsonProtocol { Utils.jsonOption(json \ "Input Metrics").foreach { inJson => val readMethod = DataReadMethod.withName((inJson \ "Data Read Method").extract[String]) val inputMetrics = metrics.registerInputMetrics(readMethod) - inputMetrics.incBytesRead((inJson \ "Bytes Read").extract[Long]) - inputMetrics.incRecordsRead((inJson \ "Records Read").extractOpt[Long].getOrElse(0L)) + inputMetrics.incBytesReadInternal((inJson \ "Bytes Read").extract[Long]) + inputMetrics.incRecordsReadInternal((inJson \ "Records Read").extractOpt[Long].getOrElse(0L)) } // Updated blocks diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index 9703b16c86f9..3605150b3b76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -214,7 +214,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( } havePair = false if (!finished) { - inputMetrics.incRecordsRead(1) + inputMetrics.incRecordsReadInternal(1) } if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { updateBytesRead() @@ -246,7 +246,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. 
try { - inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) + inputMetrics.incBytesReadInternal(split.serializableHadoopSplit.value.getLength) } catch { case e: java.io.IOException => logWarning("Unable to get input size to set InputMetrics for task", e) From a64831124c215f56f124747fa241560c70cf0a36 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 3 Feb 2016 19:32:41 -0800 Subject: [PATCH 120/131] [SPARK-13079][SQL] Extend and implement InMemoryCatalog This is a step towards consolidating `SQLContext` and `HiveContext`. This patch extends the existing Catalog API added in #10982 to include methods for handling table partitions. In particular, a partition is identified by `PartitionSpec`, which is just a `Map[String, String]`. The Catalog is still not used by anything yet, but its API is now more or less complete and an implementation is fully tested. About 200 lines are test code. Author: Andrew Or Closes #11069 from andrewor14/catalog. --- .../catalyst/catalog/InMemoryCatalog.scala | 129 ++++++++--- .../sql/catalyst/catalog/interface.scala | 40 +++- .../catalyst/catalog/CatalogTestCases.scala | 206 +++++++++++++++++- 3 files changed, 328 insertions(+), 47 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 9e6dfb7e9506..38be61c52a95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -28,9 +28,10 @@ import org.apache.spark.sql.AnalysisException * All public methods should be synchronized for thread-safety. */ class InMemoryCatalog extends Catalog { + import Catalog._ private class TableDesc(var table: Table) { - val partitions = new mutable.HashMap[String, TablePartition] + val partitions = new mutable.HashMap[PartitionSpec, TablePartition] } private class DatabaseDesc(var db: Database) { @@ -46,13 +47,20 @@ class InMemoryCatalog extends Catalog { } private def existsFunction(db: String, funcName: String): Boolean = { + assertDbExists(db) catalog(db).functions.contains(funcName) } private def existsTable(db: String, table: String): Boolean = { + assertDbExists(db) catalog(db).tables.contains(table) } + private def existsPartition(db: String, table: String, spec: PartitionSpec): Boolean = { + assertTableExists(db, table) + catalog(db).tables(table).partitions.contains(spec) + } + private def assertDbExists(db: String): Unit = { if (!catalog.contains(db)) { throw new AnalysisException(s"Database $db does not exist") @@ -60,16 +68,20 @@ class InMemoryCatalog extends Catalog { } private def assertFunctionExists(db: String, funcName: String): Unit = { - assertDbExists(db) if (!existsFunction(db, funcName)) { - throw new AnalysisException(s"Function $funcName does not exists in $db database") + throw new AnalysisException(s"Function $funcName does not exist in $db database") } } private def assertTableExists(db: String, table: String): Unit = { - assertDbExists(db) if (!existsTable(db, table)) { - throw new AnalysisException(s"Table $table does not exists in $db database") + throw new AnalysisException(s"Table $table does not exist in $db database") + } + } + + private def assertPartitionExists(db: String, table: String, spec: PartitionSpec): Unit = { + if (!existsPartition(db, table, spec)) { + throw new AnalysisException(s"Partition does not exist in database $db table 
$table: $spec") } } @@ -77,9 +89,11 @@ class InMemoryCatalog extends Catalog { // Databases // -------------------------------------------------------------------------- - override def createDatabase(dbDefinition: Database, ifNotExists: Boolean): Unit = synchronized { + override def createDatabase( + dbDefinition: Database, + ignoreIfExists: Boolean): Unit = synchronized { if (catalog.contains(dbDefinition.name)) { - if (!ifNotExists) { + if (!ignoreIfExists) { throw new AnalysisException(s"Database ${dbDefinition.name} already exists.") } } else { @@ -88,9 +102,9 @@ class InMemoryCatalog extends Catalog { } override def dropDatabase( - db: String, - ignoreIfNotExists: Boolean, - cascade: Boolean): Unit = synchronized { + db: String, + ignoreIfNotExists: Boolean, + cascade: Boolean): Unit = synchronized { if (catalog.contains(db)) { if (!cascade) { // If cascade is false, make sure the database is empty. @@ -133,11 +147,13 @@ class InMemoryCatalog extends Catalog { // Tables // -------------------------------------------------------------------------- - override def createTable(db: String, tableDefinition: Table, ifNotExists: Boolean) - : Unit = synchronized { + override def createTable( + db: String, + tableDefinition: Table, + ignoreIfExists: Boolean): Unit = synchronized { assertDbExists(db) if (existsTable(db, tableDefinition.name)) { - if (!ifNotExists) { + if (!ignoreIfExists) { throw new AnalysisException(s"Table ${tableDefinition.name} already exists in $db database") } } else { @@ -145,8 +161,10 @@ class InMemoryCatalog extends Catalog { } } - override def dropTable(db: String, table: String, ignoreIfNotExists: Boolean) - : Unit = synchronized { + override def dropTable( + db: String, + table: String, + ignoreIfNotExists: Boolean): Unit = synchronized { assertDbExists(db) if (existsTable(db, table)) { catalog(db).tables.remove(table) @@ -190,14 +208,67 @@ class InMemoryCatalog extends Catalog { // Partitions // -------------------------------------------------------------------------- - override def alterPartition(db: String, table: String, part: TablePartition) - : Unit = synchronized { - throw new UnsupportedOperationException + override def createPartitions( + db: String, + table: String, + parts: Seq[TablePartition], + ignoreIfExists: Boolean): Unit = synchronized { + assertTableExists(db, table) + val existingParts = catalog(db).tables(table).partitions + if (!ignoreIfExists) { + val dupSpecs = parts.collect { case p if existingParts.contains(p.spec) => p.spec } + if (dupSpecs.nonEmpty) { + val dupSpecsStr = dupSpecs.mkString("\n===\n") + throw new AnalysisException( + s"The following partitions already exist in database $db table $table:\n$dupSpecsStr") + } + } + parts.foreach { p => existingParts.put(p.spec, p) } + } + + override def dropPartitions( + db: String, + table: String, + partSpecs: Seq[PartitionSpec], + ignoreIfNotExists: Boolean): Unit = synchronized { + assertTableExists(db, table) + val existingParts = catalog(db).tables(table).partitions + if (!ignoreIfNotExists) { + val missingSpecs = partSpecs.collect { case s if !existingParts.contains(s) => s } + if (missingSpecs.nonEmpty) { + val missingSpecsStr = missingSpecs.mkString("\n===\n") + throw new AnalysisException( + s"The following partitions do not exist in database $db table $table:\n$missingSpecsStr") + } + } + partSpecs.foreach(existingParts.remove) } - override def alterPartitions(db: String, table: String, parts: Seq[TablePartition]) - : Unit = synchronized { - throw new UnsupportedOperationException + 
override def alterPartition( + db: String, + table: String, + spec: Map[String, String], + newPart: TablePartition): Unit = synchronized { + assertPartitionExists(db, table, spec) + val existingParts = catalog(db).tables(table).partitions + if (spec != newPart.spec) { + // Also a change in specs; remove the old one and add the new one back + existingParts.remove(spec) + } + existingParts.put(newPart.spec, newPart) + } + + override def getPartition( + db: String, + table: String, + spec: Map[String, String]): TablePartition = synchronized { + assertPartitionExists(db, table, spec) + catalog(db).tables(table).partitions(spec) + } + + override def listPartitions(db: String, table: String): Seq[TablePartition] = synchronized { + assertTableExists(db, table) + catalog(db).tables(table).partitions.values.toSeq } // -------------------------------------------------------------------------- @@ -205,11 +276,12 @@ class InMemoryCatalog extends Catalog { // -------------------------------------------------------------------------- override def createFunction( - db: String, func: Function, ifNotExists: Boolean): Unit = synchronized { + db: String, + func: Function, + ignoreIfExists: Boolean): Unit = synchronized { assertDbExists(db) - if (existsFunction(db, func.name)) { - if (!ifNotExists) { + if (!ignoreIfExists) { throw new AnalysisException(s"Function $func already exists in $db database") } } else { @@ -222,14 +294,16 @@ class InMemoryCatalog extends Catalog { catalog(db).functions.remove(funcName) } - override def alterFunction(db: String, funcName: String, funcDefinition: Function) - : Unit = synchronized { + override def alterFunction( + db: String, + funcName: String, + funcDefinition: Function): Unit = synchronized { assertFunctionExists(db, funcName) if (funcName != funcDefinition.name) { // Also a rename; remove the old one and add the new one back catalog(db).functions.remove(funcName) } - catalog(db).functions.put(funcName, funcDefinition) + catalog(db).functions.put(funcDefinition.name, funcDefinition) } override def getFunction(db: String, funcName: String): Function = synchronized { @@ -239,7 +313,6 @@ class InMemoryCatalog extends Catalog { override def listFunctions(db: String, pattern: String): Seq[String] = synchronized { assertDbExists(db) - val regex = pattern.replaceAll("\\*", ".*").r filterPattern(catalog(db).functions.keysIterator.toSeq, pattern) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index a6caf91f3304..b4d7dd2f4e31 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -29,17 +29,15 @@ import org.apache.spark.sql.AnalysisException * Implementations should throw [[AnalysisException]] when table or database don't exist. 
*/ abstract class Catalog { + import Catalog._ // -------------------------------------------------------------------------- // Databases // -------------------------------------------------------------------------- - def createDatabase(dbDefinition: Database, ifNotExists: Boolean): Unit + def createDatabase(dbDefinition: Database, ignoreIfExists: Boolean): Unit - def dropDatabase( - db: String, - ignoreIfNotExists: Boolean, - cascade: Boolean): Unit + def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit def alterDatabase(db: String, dbDefinition: Database): Unit @@ -71,11 +69,28 @@ abstract class Catalog { // Partitions // -------------------------------------------------------------------------- - // TODO: need more functions for partitioning. + def createPartitions( + db: String, + table: String, + parts: Seq[TablePartition], + ignoreIfExists: Boolean): Unit - def alterPartition(db: String, table: String, part: TablePartition): Unit + def dropPartitions( + db: String, + table: String, + parts: Seq[PartitionSpec], + ignoreIfNotExists: Boolean): Unit - def alterPartitions(db: String, table: String, parts: Seq[TablePartition]): Unit + def alterPartition( + db: String, + table: String, + spec: PartitionSpec, + newPart: TablePartition): Unit + + def getPartition(db: String, table: String, spec: PartitionSpec): TablePartition + + // TODO: support listing by pattern + def listPartitions(db: String, table: String): Seq[TablePartition] // -------------------------------------------------------------------------- // Functions @@ -132,11 +147,11 @@ case class Column( /** * A partition (Hive style) defined in the catalog. * - * @param values values for the partition columns + * @param spec partition spec values indexed by column name * @param storage storage format of the partition */ case class TablePartition( - values: Seq[String], + spec: Catalog.PartitionSpec, storage: StorageFormat ) @@ -176,3 +191,8 @@ case class Database( locationUri: String, properties: Map[String, String] ) + + +object Catalog { + type PartitionSpec = Map[String, String] +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala index ab9d5ac8a20e..0d8434323fcb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala @@ -27,6 +27,11 @@ import org.apache.spark.sql.AnalysisException * Implementations of the [[Catalog]] interface can create test suites by extending this. 
*/ abstract class CatalogTestCases extends SparkFunSuite { + private val storageFormat = StorageFormat("usa", "$", "zzz", "serde", Map.empty[String, String]) + private val part1 = TablePartition(Map[String, String]("a" -> "1"), storageFormat) + private val part2 = TablePartition(Map[String, String]("b" -> "2"), storageFormat) + private val part3 = TablePartition(Map[String, String]("c" -> "3"), storageFormat) + private val funcClass = "org.apache.spark.myFunc" protected def newEmptyCatalog(): Catalog @@ -41,16 +46,16 @@ abstract class CatalogTestCases extends SparkFunSuite { */ private def newBasicCatalog(): Catalog = { val catalog = newEmptyCatalog() - catalog.createDatabase(newDb("db1"), ifNotExists = false) - catalog.createDatabase(newDb("db2"), ifNotExists = false) - + catalog.createDatabase(newDb("db1"), ignoreIfExists = false) + catalog.createDatabase(newDb("db2"), ignoreIfExists = false) catalog.createTable("db2", newTable("tbl1"), ignoreIfExists = false) catalog.createTable("db2", newTable("tbl2"), ignoreIfExists = false) catalog.createFunction("db2", newFunc("func1"), ignoreIfExists = false) + catalog.createPartitions("db2", "tbl2", Seq(part1, part2), ignoreIfExists = false) catalog } - private def newFunc(): Function = Function("funcname", "org.apache.spark.MyFunc") + private def newFunc(): Function = Function("funcname", funcClass) private def newDb(name: String = "default"): Database = Database(name, name + " description", "uri", Map.empty) @@ -59,7 +64,7 @@ abstract class CatalogTestCases extends SparkFunSuite { Table(name, "", Seq.empty, Seq.empty, Seq.empty, null, 0, Map.empty, "EXTERNAL_TABLE", 0, 0, None, None) - private def newFunc(name: String): Function = Function(name, "class.name") + private def newFunc(name: String): Function = Function(name, funcClass) // -------------------------------------------------------------------------- // Databases @@ -67,10 +72,10 @@ abstract class CatalogTestCases extends SparkFunSuite { test("basic create, drop and list databases") { val catalog = newEmptyCatalog() - catalog.createDatabase(newDb(), ifNotExists = false) + catalog.createDatabase(newDb(), ignoreIfExists = false) assert(catalog.listDatabases().toSet == Set("default")) - catalog.createDatabase(newDb("default2"), ifNotExists = false) + catalog.createDatabase(newDb("default2"), ignoreIfExists = false) assert(catalog.listDatabases().toSet == Set("default", "default2")) } @@ -253,11 +258,194 @@ abstract class CatalogTestCases extends SparkFunSuite { // Partitions // -------------------------------------------------------------------------- - // TODO: Add tests cases for partitions + test("basic create and list partitions") { + val catalog = newEmptyCatalog() + catalog.createDatabase(newDb("mydb"), ignoreIfExists = false) + catalog.createTable("mydb", newTable("mytbl"), ignoreIfExists = false) + catalog.createPartitions("mydb", "mytbl", Seq(part1, part2), ignoreIfExists = false) + assert(catalog.listPartitions("mydb", "mytbl").toSet == Set(part1, part2)) + } + + test("create partitions when database / table does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.createPartitions("does_not_exist", "tbl1", Seq(), ignoreIfExists = false) + } + intercept[AnalysisException] { + catalog.createPartitions("db2", "does_not_exist", Seq(), ignoreIfExists = false) + } + } + + test("create partitions that already exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.createPartitions("db2", "tbl2", Seq(part1), 
ignoreIfExists = false) + } + catalog.createPartitions("db2", "tbl2", Seq(part1), ignoreIfExists = true) + } + + test("drop partitions") { + val catalog = newBasicCatalog() + assert(catalog.listPartitions("db2", "tbl2").toSet == Set(part1, part2)) + catalog.dropPartitions("db2", "tbl2", Seq(part1.spec), ignoreIfNotExists = false) + assert(catalog.listPartitions("db2", "tbl2").toSet == Set(part2)) + val catalog2 = newBasicCatalog() + assert(catalog2.listPartitions("db2", "tbl2").toSet == Set(part1, part2)) + catalog2.dropPartitions("db2", "tbl2", Seq(part1.spec, part2.spec), ignoreIfNotExists = false) + assert(catalog2.listPartitions("db2", "tbl2").isEmpty) + } + + test("drop partitions when database / table does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.dropPartitions("does_not_exist", "tbl1", Seq(), ignoreIfNotExists = false) + } + intercept[AnalysisException] { + catalog.dropPartitions("db2", "does_not_exist", Seq(), ignoreIfNotExists = false) + } + } + + test("drop partitions that do not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.dropPartitions("db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = false) + } + catalog.dropPartitions("db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = true) + } + + test("get partition") { + val catalog = newBasicCatalog() + assert(catalog.getPartition("db2", "tbl2", part1.spec) == part1) + assert(catalog.getPartition("db2", "tbl2", part2.spec) == part2) + intercept[AnalysisException] { + catalog.getPartition("db2", "tbl1", part3.spec) + } + } + + test("get partition when database / table does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.getPartition("does_not_exist", "tbl1", part1.spec) + } + intercept[AnalysisException] { + catalog.getPartition("db2", "does_not_exist", part1.spec) + } + } + + test("alter partitions") { + val catalog = newBasicCatalog() + val partSameSpec = part1.copy(storage = storageFormat.copy(serde = "myserde")) + val partNewSpec = part1.copy(spec = Map("x" -> "10")) + // alter but keep spec the same + catalog.alterPartition("db2", "tbl2", part1.spec, partSameSpec) + assert(catalog.getPartition("db2", "tbl2", part1.spec) == partSameSpec) + // alter and change spec + catalog.alterPartition("db2", "tbl2", part1.spec, partNewSpec) + intercept[AnalysisException] { + catalog.getPartition("db2", "tbl2", part1.spec) + } + assert(catalog.getPartition("db2", "tbl2", partNewSpec.spec) == partNewSpec) + } + + test("alter partition when database / table does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.alterPartition("does_not_exist", "tbl1", part1.spec, part1) + } + intercept[AnalysisException] { + catalog.alterPartition("db2", "does_not_exist", part1.spec, part1) + } + } // -------------------------------------------------------------------------- // Functions // -------------------------------------------------------------------------- - // TODO: Add tests cases for functions + test("basic create and list functions") { + val catalog = newEmptyCatalog() + catalog.createDatabase(newDb("mydb"), ignoreIfExists = false) + catalog.createFunction("mydb", newFunc("myfunc"), ignoreIfExists = false) + assert(catalog.listFunctions("mydb", "*").toSet == Set("myfunc")) + } + + test("create function when database does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.createFunction("does_not_exist", newFunc(), ignoreIfExists = false) + 
} + } + + test("create function that already exists") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.createFunction("db2", newFunc("func1"), ignoreIfExists = false) + } + catalog.createFunction("db2", newFunc("func1"), ignoreIfExists = true) + } + + test("drop function") { + val catalog = newBasicCatalog() + assert(catalog.listFunctions("db2", "*").toSet == Set("func1")) + catalog.dropFunction("db2", "func1") + assert(catalog.listFunctions("db2", "*").isEmpty) + } + + test("drop function when database does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.dropFunction("does_not_exist", "something") + } + } + + test("drop function that does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.dropFunction("db2", "does_not_exist") + } + } + + test("get function") { + val catalog = newBasicCatalog() + assert(catalog.getFunction("db2", "func1") == newFunc("func1")) + intercept[AnalysisException] { + catalog.getFunction("db2", "does_not_exist") + } + } + + test("get function when database does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.getFunction("does_not_exist", "func1") + } + } + + test("alter function") { + val catalog = newBasicCatalog() + assert(catalog.getFunction("db2", "func1").className == funcClass) + // alter func but keep name + catalog.alterFunction("db2", "func1", newFunc("func1").copy(className = "muhaha")) + assert(catalog.getFunction("db2", "func1").className == "muhaha") + // alter func and change name + catalog.alterFunction("db2", "func1", newFunc("funcky")) + intercept[AnalysisException] { + catalog.getFunction("db2", "func1") + } + assert(catalog.getFunction("db2", "funcky").className == funcClass) + } + + test("alter function when database does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.alterFunction("does_not_exist", "func1", newFunc()) + } + } + + test("list functions") { + val catalog = newBasicCatalog() + catalog.createFunction("db2", newFunc("func2"), ignoreIfExists = false) + catalog.createFunction("db2", newFunc("not_me"), ignoreIfExists = false) + assert(catalog.listFunctions("db2", "*").toSet == Set("func1", "func2", "not_me")) + assert(catalog.listFunctions("db2", "func*").toSet == Set("func1", "func2")) + } + } From 0f81318ae217346c20894572795e1a9cee2ebc8f Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 3 Feb 2016 21:05:53 -0800 Subject: [PATCH 121/131] [SPARK-12828][SQL] add natural join support Jira: https://issues.apache.org/jira/browse/SPARK-12828 Author: Daoyuan Wang Closes #10762 from adrian-wang/naturaljoin. 
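For context, a minimal usage sketch of the feature (assumes a spark-shell style environment where sqlContext is in scope; the data and table names mirror the new SQLQuerySuite test further down):

    import sqlContext.implicits._

    val df1 = Seq(("one", 1), ("two", 2), ("three", 3)).toDF("k", "v1")
    val df2 = Seq(("one", 1), ("two", 22), ("one", 5)).toDF("k", "v2")
    df1.registerTempTable("nt1")
    df2.registerTempTable("nt2")

    // NATURAL JOIN equi-joins on every column the two tables share by name ("k" here).
    // The new ResolveNaturalJoin analyzer rule rewrites it into an ordinary inner join
    // on k plus a Project, so the output keeps a single copy of k: (k, v1, v2).
    sqlContext.sql("SELECT * FROM nt1 NATURAL JOIN nt2").show()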
--- .../sql/catalyst/parser/FromClauseParser.g | 23 +++-- .../spark/sql/catalyst/parser/SparkSqlLexer.g | 2 + .../sql/catalyst/parser/SparkSqlParser.g | 4 + .../spark/sql/catalyst/CatalystQl.scala | 4 + .../sql/catalyst/analysis/Analyzer.scala | 43 +++++++++ .../sql/catalyst/optimizer/Optimizer.scala | 4 +- .../spark/sql/catalyst/plans/joinTypes.scala | 4 + .../plans/logical/basicOperators.scala | 10 ++- .../analysis/ResolveNaturalJoinSuite.scala | 90 +++++++++++++++++++ .../org/apache/spark/sql/DataFrame.scala | 1 + .../org/apache/spark/sql/SQLQuerySuite.scala | 24 +++++ 11 files changed, 198 insertions(+), 11 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g index 6d76afcd4ac0..e83f8a7cd1b5 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g @@ -117,15 +117,20 @@ joinToken @init { gParent.pushMsg("join type specifier", state); } @after { gParent.popMsg(state); } : - KW_JOIN -> TOK_JOIN - | KW_INNER KW_JOIN -> TOK_JOIN - | COMMA -> TOK_JOIN - | KW_CROSS KW_JOIN -> TOK_CROSSJOIN - | KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_LEFTOUTERJOIN - | KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_RIGHTOUTERJOIN - | KW_FULL (KW_OUTER)? KW_JOIN -> TOK_FULLOUTERJOIN - | KW_LEFT KW_SEMI KW_JOIN -> TOK_LEFTSEMIJOIN - | KW_ANTI KW_JOIN -> TOK_ANTIJOIN + KW_JOIN -> TOK_JOIN + | KW_INNER KW_JOIN -> TOK_JOIN + | KW_NATURAL KW_JOIN -> TOK_NATURALJOIN + | KW_NATURAL KW_INNER KW_JOIN -> TOK_NATURALJOIN + | COMMA -> TOK_JOIN + | KW_CROSS KW_JOIN -> TOK_CROSSJOIN + | KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_LEFTOUTERJOIN + | KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_RIGHTOUTERJOIN + | KW_FULL (KW_OUTER)? KW_JOIN -> TOK_FULLOUTERJOIN + | KW_NATURAL KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_NATURALLEFTOUTERJOIN + | KW_NATURAL KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_NATURALRIGHTOUTERJOIN + | KW_NATURAL KW_FULL (KW_OUTER)? KW_JOIN -> TOK_NATURALFULLOUTERJOIN + | KW_LEFT KW_SEMI KW_JOIN -> TOK_LEFTSEMIJOIN + | KW_ANTI KW_JOIN -> TOK_ANTIJOIN ; lateralView diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g index 1d07a27353dc..fd1ad59207e3 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g @@ -335,6 +335,8 @@ KW_CACHE: 'CACHE'; KW_UNCACHE: 'UNCACHE'; KW_DFS: 'DFS'; +KW_NATURAL: 'NATURAL'; + // Operators // NOTE: if you add a new function/operator, add it to sysFuncNames so that describe function _FUNC_ will work. 
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g index 6591f6b0f56c..9935678ca2ca 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g @@ -96,6 +96,10 @@ TOK_RIGHTOUTERJOIN; TOK_FULLOUTERJOIN; TOK_UNIQUEJOIN; TOK_CROSSJOIN; +TOK_NATURALJOIN; +TOK_NATURALLEFTOUTERJOIN; +TOK_NATURALRIGHTOUTERJOIN; +TOK_NATURALFULLOUTERJOIN; TOK_LOAD; TOK_EXPORT; TOK_IMPORT; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index 7ce2407913ad..a42360d5629f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -520,6 +520,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case "TOK_LEFTSEMIJOIN" => LeftSemi case "TOK_UNIQUEJOIN" => noParseRule("Unique Join", node) case "TOK_ANTIJOIN" => noParseRule("Anti Join", node) + case "TOK_NATURALJOIN" => NaturalJoin(Inner) + case "TOK_NATURALRIGHTOUTERJOIN" => NaturalJoin(RightOuter) + case "TOK_NATURALLEFTOUTERJOIN" => NaturalJoin(LeftOuter) + case "TOK_NATURALFULLOUTERJOIN" => NaturalJoin(FullOuter) } Join(nodeToRelation(relation1), nodeToRelation(relation2), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a983dc1cdfeb..b30ed5928fd5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef @@ -81,6 +82,7 @@ class Analyzer( ResolveAliases :: ResolveWindowOrder :: ResolveWindowFrame :: + ResolveNaturalJoin :: ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: @@ -1230,6 +1232,47 @@ class Analyzer( } } } + + /** + * Removes natural joins by calculating output columns based on output from two sides, + * Then apply a Project on a normal Join to eliminate natural join. + */ + object ResolveNaturalJoin extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + // Should not skip unresolved nodes because natural join is always unresolved. 
+ case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural => + // find common column names from both sides, should be treated like usingColumns + val joinNames = left.output.map(_.name).intersect(right.output.map(_.name)) + val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get) + val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get) + val joinPairs = leftKeys.zip(rightKeys) + // Add joinPairs to joinConditions + val newCondition = (condition ++ joinPairs.map { + case (l, r) => EqualTo(l, r) + }).reduceLeftOption(And) + // columns not in joinPairs + val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att)) + val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att)) + // we should only keep unique columns(depends on joinType) for joinCols + val projectList = joinType match { + case LeftOuter => + leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true)) + case RightOuter => + rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput + case FullOuter => + // in full outer join, joinCols should be non-null if there is. + val joinedCols = joinPairs.map { + case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() + } + joinedCols ++ lUniqueOutput.map(_.withNullability(true)) ++ + rUniqueOutput.map(_.withNullability(true)) + case _ => + rightKeys ++ lUniqueOutput ++ rUniqueOutput + } + // use Project to trim unnecessary fields + Project(projectList, Join(left, right, joinType, newCondition)) + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f156b5d10acc..4ecee7504824 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions} -import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, LeftSemi, RightOuter} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types._ @@ -905,6 +905,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { (rightFilterConditions ++ commonFilterCondition). 
reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin) case FullOuter => f // DO Nothing for Full Outer Join + case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node") } // push down the join filter into sub query scanning if applicable @@ -939,6 +940,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { Join(newLeft, newRight, LeftOuter, newJoinCond) case FullOuter => f + case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index a5f6764aef7c..b10f1e63a73e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -60,3 +60,7 @@ case object FullOuter extends JoinType { case object LeftSemi extends JoinType { override def sql: String = "LEFT SEMI" } + +case class NaturalJoin(tpe: JoinType) extends JoinType { + override def sql: String = "NATURAL " + tpe.sql +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 8150ff843476..03a79520cbd3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -250,12 +250,20 @@ case class Join( def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty // Joins are only resolved if they don't introduce ambiguous expression ids. - override lazy val resolved: Boolean = { + // NaturalJoin should be ready for resolution only if everything else is resolved here + lazy val resolvedExceptNatural: Boolean = { childrenResolved && expressions.forall(_.resolved) && duplicateResolved && condition.forall(_.dataType == BooleanType) } + + // if not a natural join, use `resolvedExceptNatural`. if it is a natural join, we still need + // to eliminate natural before we mark it resolved. + override lazy val resolved: Boolean = joinType match { + case NaturalJoin(_) => false + case _ => resolvedExceptNatural + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala new file mode 100644 index 000000000000..a6554fbc414b --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation + +class ResolveNaturalJoinSuite extends AnalysisTest { + lazy val a = 'a.string + lazy val b = 'b.string + lazy val c = 'c.string + lazy val aNotNull = a.notNull + lazy val bNotNull = b.notNull + lazy val cNotNull = c.notNull + lazy val r1 = LocalRelation(a, b) + lazy val r2 = LocalRelation(a, c) + lazy val r3 = LocalRelation(aNotNull, bNotNull) + lazy val r4 = LocalRelation(bNotNull, cNotNull) + + test("natural inner join") { + val plan = r1.join(r2, NaturalJoin(Inner), None) + val expected = r1.join(r2, Inner, Some(EqualTo(a, a))).select(a, b, c) + checkAnalysis(plan, expected) + } + + test("natural left join") { + val plan = r1.join(r2, NaturalJoin(LeftOuter), None) + val expected = r1.join(r2, LeftOuter, Some(EqualTo(a, a))).select(a, b, c) + checkAnalysis(plan, expected) + } + + test("natural right join") { + val plan = r1.join(r2, NaturalJoin(RightOuter), None) + val expected = r1.join(r2, RightOuter, Some(EqualTo(a, a))).select(a, b, c) + checkAnalysis(plan, expected) + } + + test("natural full outer join") { + val plan = r1.join(r2, NaturalJoin(FullOuter), None) + val expected = r1.join(r2, FullOuter, Some(EqualTo(a, a))).select( + Alias(Coalesce(Seq(a, a)), "a")(), b, c) + checkAnalysis(plan, expected) + } + + test("natural inner join with no nullability") { + val plan = r3.join(r4, NaturalJoin(Inner), None) + val expected = r3.join(r4, Inner, Some(EqualTo(bNotNull, bNotNull))).select( + bNotNull, aNotNull, cNotNull) + checkAnalysis(plan, expected) + } + + test("natural left join with no nullability") { + val plan = r3.join(r4, NaturalJoin(LeftOuter), None) + val expected = r3.join(r4, LeftOuter, Some(EqualTo(bNotNull, bNotNull))).select( + bNotNull, aNotNull, c) + checkAnalysis(plan, expected) + } + + test("natural right join with no nullability") { + val plan = r3.join(r4, NaturalJoin(RightOuter), None) + val expected = r3.join(r4, RightOuter, Some(EqualTo(bNotNull, bNotNull))).select( + bNotNull, a, cNotNull) + checkAnalysis(plan, expected) + } + + test("natural full outer join with no nullability") { + val plan = r3.join(r4, NaturalJoin(FullOuter), None) + val expected = r3.join(r4, FullOuter, Some(EqualTo(bNotNull, bNotNull))).select( + Alias(Coalesce(Seq(bNotNull, bNotNull)), "b")(), a, c) + checkAnalysis(plan, expected) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 84203bbfef66..f15b926bd27c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -474,6 +474,7 @@ class DataFrame private[sql]( val rightCol = withPlan(joined.right).resolve(col).toAttribute.withNullability(true) Alias(Coalesce(Seq(leftCol, rightCol)), col)() } + case NaturalJoin(_) => sys.error("NaturalJoin with using clause is not supported.") } // The nullability of output of joined could be different than original column, // so we can only compare them by exprId diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 79bfd4b44b70..8ef7b61314a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2075,4 +2075,28 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) } } + + test("natural join") { + val df1 = Seq(("one", 1), ("two", 2), ("three", 3)).toDF("k", "v1") + val df2 = Seq(("one", 1), ("two", 22), ("one", 5)).toDF("k", "v2") + withTempTable("nt1", "nt2") { + df1.registerTempTable("nt1") + df2.registerTempTable("nt2") + checkAnswer( + sql("SELECT * FROM nt1 natural join nt2 where k = \"one\""), + Row("one", 1, 1) :: Row("one", 1, 5) :: Nil) + + checkAnswer( + sql("SELECT * FROM nt1 natural left join nt2 order by v1, v2"), + Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Row("three", 3, null) :: Nil) + + checkAnswer( + sql("SELECT * FROM nt1 natural right join nt2 order by v1, v2"), + Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Nil) + + checkAnswer( + sql("SELECT count(*) FROM nt1 natural full outer join nt2"), + Row(4) :: Nil) + } + } } From c2c956bcd1a75fd01868ee9ad2939a6d3de52bc2 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 3 Feb 2016 21:19:44 -0800 Subject: [PATCH 122/131] [ML][DOC] fix wrong api link in ml onevsrest minor fix for api link in ml onevsrest Author: Yuhao Yang Closes #11068 from hhbyyh/onevsrestDoc. --- docs/ml-classification-regression.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index 8ffc997b4bf5..9569a06472cb 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -289,7 +289,7 @@ The example below demonstrates how to load the
        -Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classifier.OneVsRest) for more details. +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.OneVsRest) for more details. {% include_example scala/org/apache/spark/examples/ml/OneVsRestExample.scala %}
        From d39087147ff1052b623cdba69ffbde28b266745f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 3 Feb 2016 23:17:51 -0800 Subject: [PATCH 123/131] [SPARK-13113] [CORE] Remove unnecessary bit operation when decoding page number JIRA: https://issues.apache.org/jira/browse/SPARK-13113 As we shift bits right, looks like the bitwise AND operation is unnecessary. Author: Liang-Chi Hsieh Closes #11002 from viirya/improve-decodepagenumber. --- .../main/java/org/apache/spark/memory/TaskMemoryManager.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index d31eb449eb82..d2a88864f7ac 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -312,7 +312,7 @@ public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) @VisibleForTesting public static int decodePageNumber(long pagePlusOffsetAddress) { - return (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> OFFSET_BITS); + return (int) (pagePlusOffsetAddress >>> OFFSET_BITS); } private static long decodeOffset(long pagePlusOffsetAddress) { From dee801adb78d6abd0abbf76b4dfa71aa296b4f0b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 3 Feb 2016 23:43:48 -0800 Subject: [PATCH 124/131] [SPARK-12828][SQL] Natural join follow-up This is a small addendum to #10762 to make the code more robust again future changes. Author: Reynold Xin Closes #11070 from rxin/SPARK-12828-natural-join. --- .../sql/catalyst/analysis/Analyzer.scala | 21 +++++++++++-------- .../spark/sql/catalyst/plans/joinTypes.scala | 2 ++ .../analysis/ResolveNaturalJoinSuite.scala | 6 +++--- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b30ed5928fd5..b59eb12419c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1239,21 +1239,23 @@ class Analyzer( */ object ResolveNaturalJoin extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - // Should not skip unresolved nodes because natural join is always unresolved. 
case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural => - // find common column names from both sides, should be treated like usingColumns + // find common column names from both sides val joinNames = left.output.map(_.name).intersect(right.output.map(_.name)) val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get) val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get) val joinPairs = leftKeys.zip(rightKeys) + // Add joinPairs to joinConditions val newCondition = (condition ++ joinPairs.map { case (l, r) => EqualTo(l, r) - }).reduceLeftOption(And) + }).reduceOption(And) + // columns not in joinPairs val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att)) val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att)) - // we should only keep unique columns(depends on joinType) for joinCols + + // the output list looks like: join keys, columns from left, columns from right val projectList = joinType match { case LeftOuter => leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true)) @@ -1261,13 +1263,14 @@ class Analyzer( rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput case FullOuter => // in full outer join, joinCols should be non-null if there is. - val joinedCols = joinPairs.map { - case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() - } - joinedCols ++ lUniqueOutput.map(_.withNullability(true)) ++ + val joinedCols = joinPairs.map { case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() } + joinedCols ++ + lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput.map(_.withNullability(true)) - case _ => + case Inner => rightKeys ++ lUniqueOutput ++ rUniqueOutput + case _ => + sys.error("Unsupported natural join type " + joinType) } // use Project to trim unnecessary fields Project(projectList, Join(left, right, joinType, newCondition)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index b10f1e63a73e..27a75326eba0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -62,5 +62,7 @@ case object LeftSemi extends JoinType { } case class NaturalJoin(tpe: JoinType) extends JoinType { + require(Seq(Inner, LeftOuter, RightOuter, FullOuter).contains(tpe), + "Unsupported natural join type " + tpe) override def sql: String = "NATURAL " + tpe.sql } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala index a6554fbc414b..fcf4ac1967a5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala @@ -30,10 +30,10 @@ class ResolveNaturalJoinSuite extends AnalysisTest { lazy val aNotNull = a.notNull lazy val bNotNull = b.notNull lazy val cNotNull = c.notNull - lazy val r1 = LocalRelation(a, b) - lazy val r2 = LocalRelation(a, c) + lazy val r1 = LocalRelation(b, a) + lazy val r2 = LocalRelation(c, a) lazy val r3 = LocalRelation(aNotNull, bNotNull) - lazy val r4 = LocalRelation(bNotNull, cNotNull) + lazy val r4 = LocalRelation(cNotNull, bNotNull) test("natural inner join") { val plan = r1.join(r2, 
NaturalJoin(Inner), None) From 2eaeafe8a2aa31be9b230b8d53d3baccd32535b1 Mon Sep 17 00:00:00 2001 From: Charles Allen Date: Thu, 4 Feb 2016 10:27:25 -0800 Subject: [PATCH 125/131] [SPARK-12330][MESOS] Fix mesos coarse mode cleanup In the current implementation the mesos coarse scheduler does not wait for the mesos tasks to complete before ending the driver. This causes a race where the task has to finish cleaning up before the mesos driver terminates it with a SIGINT (and SIGKILL after 3 seconds if the SIGINT doesn't work). This PR causes the mesos coarse scheduler to wait for the mesos tasks to finish (with a timeout defined by `spark.mesos.coarse.shutdown.ms`) This PR also fixes a regression caused by [SPARK-10987] whereby submitting a shutdown causes a race between the local shutdown procedure and the notification of the scheduler driver disconnection. If the scheduler driver disconnection wins the race, the coarse executor incorrectly exits with status 1 (instead of the proper status 0) With this patch the mesos coarse scheduler terminates properly, the executors clean up, and the tasks are reported as `FINISHED` in the Mesos console (as opposed to `KILLED` in < 1.6 or `FAILED` in 1.6 and later) Author: Charles Allen Closes #10319 from drcrallen/SPARK-12330. --- .../CoarseGrainedExecutorBackend.scala | 8 +++- .../mesos/CoarseMesosSchedulerBackend.scala | 39 ++++++++++++++++++- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 136cf4a84d38..3b5cb18da1b2 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -19,6 +19,7 @@ package org.apache.spark.executor import java.net.URL import java.nio.ByteBuffer +import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable import scala.util.{Failure, Success} @@ -42,6 +43,7 @@ private[spark] class CoarseGrainedExecutorBackend( env: SparkEnv) extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging { + private[this] val stopping = new AtomicBoolean(false) var executor: Executor = null @volatile var driver: Option[RpcEndpointRef] = None @@ -102,19 +104,23 @@ private[spark] class CoarseGrainedExecutorBackend( } case StopExecutor => + stopping.set(true) logInfo("Driver commanded a shutdown") // Cannot shutdown here because an ack may need to be sent back to the caller. So send // a message to self to actually do the shutdown. self.send(Shutdown) case Shutdown => + stopping.set(true) executor.stop() stop() rpcEnv.shutdown() } override def onDisconnected(remoteAddress: RpcAddress): Unit = { - if (driver.exists(_.address == remoteAddress)) { + if (stopping.get()) { + logInfo(s"Driver from $remoteAddress disconnected during shutdown") + } else if (driver.exists(_.address == remoteAddress)) { logError(s"Driver $remoteAddress disassociated! 
Shutting down.") System.exit(1) } else { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 2f095b86c69e..722293bb7a53 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -19,11 +19,13 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File import java.util.{Collections, List => JList} +import java.util.concurrent.TimeUnit import java.util.concurrent.locks.ReentrantLock import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, HashSet} +import com.google.common.base.Stopwatch import com.google.common.collect.HashBiMap import org.apache.mesos.{Scheduler => MScheduler, SchedulerDriver} import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} @@ -60,6 +62,12 @@ private[spark] class CoarseMesosSchedulerBackend( // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt + private[this] val shutdownTimeoutMS = conf.getTimeAsMs("spark.mesos.coarse.shutdown.ms", "10s") + .ensuring(_ >= 0, "spark.mesos.coarse.shutdown.ms must be >= 0") + + // Synchronization protected by stateLock + private[this] var stopCalled: Boolean = false + // If shuffle service is enabled, the Spark driver will register with the shuffle service. // This is for cleaning up shuffle files reliably. private val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) @@ -245,6 +253,13 @@ private[spark] class CoarseMesosSchedulerBackend( */ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { stateLock.synchronized { + if (stopCalled) { + logDebug("Ignoring offers during shutdown") + // Driver should simply return a stopped status on race + // condition between this.stop() and completing here + offers.asScala.map(_.getId).foreach(d.declineOffer) + return + } val filters = Filters.newBuilder().setRefuseSeconds(5).build() for (offer <- offers.asScala) { val offerAttributes = toAttributeMap(offer.getAttributesList) @@ -364,7 +379,29 @@ private[spark] class CoarseMesosSchedulerBackend( } override def stop() { - super.stop() + // Make sure we're not launching tasks during shutdown + stateLock.synchronized { + if (stopCalled) { + logWarning("Stop called multiple times, ignoring") + return + } + stopCalled = true + super.stop() + } + // Wait for executors to report done, or else mesosDriver.stop() will forcefully kill them. + // See SPARK-12330 + val stopwatch = new Stopwatch() + stopwatch.start() + // slaveIdsWithExecutors has no memory barrier, so this is eventually consistent + while (slaveIdsWithExecutors.nonEmpty && + stopwatch.elapsed(TimeUnit.MILLISECONDS) < shutdownTimeoutMS) { + Thread.sleep(100) + } + if (slaveIdsWithExecutors.nonEmpty) { + logWarning(s"Timed out waiting for ${slaveIdsWithExecutors.size} remaining executors " + + s"to terminate within $shutdownTimeoutMS ms. 
This may leave temporary files " + + "on the mesos nodes.") + } if (mesosDriver != null) { mesosDriver.stop() } From 62a7c28388539e6fc7d16ee3009f2cf79d8635bd Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 4 Feb 2016 10:29:38 -0800 Subject: [PATCH 126/131] [SPARK-13164][CORE] Replace deprecated synchronized buffer in core Building with scala 2.11 results in the warning trait SynchronizedBuffer in package mutable is deprecated: Synchronization via traits is deprecated as it is inherently unreliable. Consider java.util.concurrent.ConcurrentLinkedQueue as an alternative. Investigation shows we are already using ConcurrentLinkedQueue in other locations so switch our uses of SynchronizedBuffer to ConcurrentLinkedQueue. Author: Holden Karau Closes #11059 from holdenk/SPARK-13164-replace-deprecated-synchronized-buffer-in-core. --- .../org/apache/spark/ContextCleaner.scala | 26 +++++++++---------- .../spark/deploy/client/AppClientSuite.scala | 20 +++++++------- .../org/apache/spark/rpc/RpcEnvSuite.scala | 23 ++++++++-------- .../apache/spark/util/EventLoopSuite.scala | 10 +++---- 4 files changed, 40 insertions(+), 39 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 5a42299a0bf8..17014e4954f9 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -18,9 +18,9 @@ package org.apache.spark import java.lang.ref.{ReferenceQueue, WeakReference} -import java.util.concurrent.{ScheduledExecutorService, TimeUnit} +import java.util.concurrent.{ConcurrentLinkedQueue, ScheduledExecutorService, TimeUnit} -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import scala.collection.JavaConverters._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData} @@ -57,13 +57,11 @@ private class CleanupTaskWeakReference( */ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { - private val referenceBuffer = new ArrayBuffer[CleanupTaskWeakReference] - with SynchronizedBuffer[CleanupTaskWeakReference] + private val referenceBuffer = new ConcurrentLinkedQueue[CleanupTaskWeakReference]() private val referenceQueue = new ReferenceQueue[AnyRef] - private val listeners = new ArrayBuffer[CleanerListener] - with SynchronizedBuffer[CleanerListener] + private val listeners = new ConcurrentLinkedQueue[CleanerListener]() private val cleaningThread = new Thread() { override def run() { keepCleaning() }} @@ -111,7 +109,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { /** Attach a listener object to get information of when objects are cleaned. */ def attachListener(listener: CleanerListener): Unit = { - listeners += listener + listeners.add(listener) } /** Start the cleaner. */ @@ -166,7 +164,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { /** Register an object for cleanup. */ private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask): Unit = { - referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue) + referenceBuffer.add(new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue)) } /** Keep cleaning RDD, shuffle, and broadcast state. 
*/ @@ -179,7 +177,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { synchronized { reference.map(_.task).foreach { task => logDebug("Got cleaning task " + task) - referenceBuffer -= reference.get + referenceBuffer.remove(reference.get) task match { case CleanRDD(rddId) => doCleanupRDD(rddId, blocking = blockOnCleanupTasks) @@ -206,7 +204,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { try { logDebug("Cleaning RDD " + rddId) sc.unpersistRDD(rddId, blocking) - listeners.foreach(_.rddCleaned(rddId)) + listeners.asScala.foreach(_.rddCleaned(rddId)) logInfo("Cleaned RDD " + rddId) } catch { case e: Exception => logError("Error cleaning RDD " + rddId, e) @@ -219,7 +217,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { logDebug("Cleaning shuffle " + shuffleId) mapOutputTrackerMaster.unregisterShuffle(shuffleId) blockManagerMaster.removeShuffle(shuffleId, blocking) - listeners.foreach(_.shuffleCleaned(shuffleId)) + listeners.asScala.foreach(_.shuffleCleaned(shuffleId)) logInfo("Cleaned shuffle " + shuffleId) } catch { case e: Exception => logError("Error cleaning shuffle " + shuffleId, e) @@ -231,7 +229,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { try { logDebug(s"Cleaning broadcast $broadcastId") broadcastManager.unbroadcast(broadcastId, true, blocking) - listeners.foreach(_.broadcastCleaned(broadcastId)) + listeners.asScala.foreach(_.broadcastCleaned(broadcastId)) logDebug(s"Cleaned broadcast $broadcastId") } catch { case e: Exception => logError("Error cleaning broadcast " + broadcastId, e) @@ -243,7 +241,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { try { logDebug("Cleaning accumulator " + accId) Accumulators.remove(accId) - listeners.foreach(_.accumCleaned(accId)) + listeners.asScala.foreach(_.accumCleaned(accId)) logInfo("Cleaned accumulator " + accId) } catch { case e: Exception => logError("Error cleaning accumulator " + accId, e) @@ -258,7 +256,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { try { logDebug("Cleaning rdd checkpoint data " + rddId) ReliableRDDCheckpointData.cleanCheckpoint(sc, rddId) - listeners.foreach(_.checkpointCleaned(rddId)) + listeners.asScala.foreach(_.checkpointCleaned(rddId)) logInfo("Cleaned rdd checkpoint data " + rddId) } catch { diff --git a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala index eb794b6739d5..658779360b7a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.deploy.client -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import java.util.concurrent.ConcurrentLinkedQueue + +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import org.scalatest.BeforeAndAfterAll @@ -165,14 +167,14 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd /** Application Listener to collect events */ private class AppClientCollector extends AppClientListener with Logging { - val connectedIdList = new ArrayBuffer[String] with SynchronizedBuffer[String] + val connectedIdList = new ConcurrentLinkedQueue[String]() @volatile var disconnectedCount: Int = 0 - val deadReasonList = new ArrayBuffer[String] with SynchronizedBuffer[String] - val execAddedList = new ArrayBuffer[String] with 
SynchronizedBuffer[String] - val execRemovedList = new ArrayBuffer[String] with SynchronizedBuffer[String] + val deadReasonList = new ConcurrentLinkedQueue[String]() + val execAddedList = new ConcurrentLinkedQueue[String]() + val execRemovedList = new ConcurrentLinkedQueue[String]() def connected(id: String): Unit = { - connectedIdList += id + connectedIdList.add(id) } def disconnected(): Unit = { @@ -182,7 +184,7 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd } def dead(reason: String): Unit = { - deadReasonList += reason + deadReasonList.add(reason) } def executorAdded( @@ -191,11 +193,11 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd hostPort: String, cores: Int, memory: Int): Unit = { - execAddedList += id + execAddedList.add(id) } def executorRemoved(id: String, message: String, exitStatus: Option[Int]): Unit = { - execRemovedList += id + execRemovedList.add(id) } } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 6f4eda8b47dd..22048003882d 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -20,9 +20,10 @@ package org.apache.spark.rpc import java.io.{File, NotSerializableException} import java.nio.charset.StandardCharsets.UTF_8 import java.util.UUID -import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit} +import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch, TimeoutException, TimeUnit} import scala.collection.mutable +import scala.collection.JavaConverters._ import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps @@ -490,30 +491,30 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { /** * Setup an [[RpcEndpoint]] to collect all network events. - * @return the [[RpcEndpointRef]] and an `Seq` that contains network events. + * @return the [[RpcEndpointRef]] and an `ConcurrentLinkedQueue` that contains network events. 
*/ private def setupNetworkEndpoint( _env: RpcEnv, - name: String): (RpcEndpointRef, Seq[(Any, Any)]) = { - val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)] + name: String): (RpcEndpointRef, ConcurrentLinkedQueue[(Any, Any)]) = { + val events = new ConcurrentLinkedQueue[(Any, Any)] val ref = _env.setupEndpoint("network-events-non-client", new ThreadSafeRpcEndpoint { override val rpcEnv = _env override def receive: PartialFunction[Any, Unit] = { case "hello" => - case m => events += "receive" -> m + case m => events.add("receive" -> m) } override def onConnected(remoteAddress: RpcAddress): Unit = { - events += "onConnected" -> remoteAddress + events.add("onConnected" -> remoteAddress) } override def onDisconnected(remoteAddress: RpcAddress): Unit = { - events += "onDisconnected" -> remoteAddress + events.add("onDisconnected" -> remoteAddress) } override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { - events += "onNetworkError" -> remoteAddress + events.add("onNetworkError" -> remoteAddress) } }) @@ -560,7 +561,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { eventually(timeout(5 seconds), interval(5 millis)) { // We don't know the exact client address but at least we can verify the message type - assert(events.map(_._1).contains("onConnected")) + assert(events.asScala.map(_._1).exists(_ == "onConnected")) } clientEnv.shutdown() @@ -568,8 +569,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { eventually(timeout(5 seconds), interval(5 millis)) { // We don't know the exact client address but at least we can verify the message type - assert(events.map(_._1).contains("onConnected")) - assert(events.map(_._1).contains("onDisconnected")) + assert(events.asScala.map(_._1).exists(_ == "onConnected")) + assert(events.asScala.map(_._1).exists(_ == "onDisconnected")) } } finally { clientEnv.shutdown() diff --git a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala index b207d497f33c..6f7dddd4f760 100644 --- a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.util -import java.util.concurrent.CountDownLatch +import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch} -import scala.collection.mutable +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.language.postfixOps @@ -31,11 +31,11 @@ import org.apache.spark.SparkFunSuite class EventLoopSuite extends SparkFunSuite with Timeouts { test("EventLoop") { - val buffer = new mutable.ArrayBuffer[Int] with mutable.SynchronizedBuffer[Int] + val buffer = new ConcurrentLinkedQueue[Int] val eventLoop = new EventLoop[Int]("test") { override def onReceive(event: Int): Unit = { - buffer += event + buffer.add(event) } override def onError(e: Throwable): Unit = {} @@ -43,7 +43,7 @@ class EventLoopSuite extends SparkFunSuite with Timeouts { eventLoop.start() (1 to 100).foreach(eventLoop.post) eventually(timeout(5 seconds), interval(5 millis)) { - assert((1 to 100) === buffer.toSeq) + assert((1 to 100) === buffer.asScala.toSeq) } eventLoop.stop() } From 4120bcbaffe92da40486b469334119ed12199f4f Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 4 Feb 2016 10:32:16 -0800 Subject: [PATCH 127/131] [SPARK-13162] Standalone mode does not respect initial executors Currently the Master would always set 
an application's initial executor limit to infinity. If the user specified `spark.dynamicAllocation.initialExecutors`, the config would not take effect. This is similar to #11047 but for standalone mode. Author: Andrew Or Closes #11054 from andrewor14/standalone-da-initial. --- .../spark/ExecutorAllocationManager.scala | 2 ++ .../spark/deploy/ApplicationDescription.scala | 3 +++ .../spark/deploy/master/ApplicationInfo.scala | 2 +- .../cluster/SparkDeploySchedulerBackend.scala | 16 ++++++++++++---- .../StandaloneDynamicAllocationSuite.scala | 17 ++++++++++++++++- 5 files changed, 34 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 3431fc13dcb4..db143d7341ce 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -231,6 +231,8 @@ private[spark] class ExecutorAllocationManager( } } executor.scheduleAtFixedRate(scheduleTask, 0, intervalMillis, TimeUnit.MILLISECONDS) + + client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala index 78bbd5c03f4a..c5c5c60923f4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala @@ -29,6 +29,9 @@ private[spark] case class ApplicationDescription( // short name of compression codec used when writing event logs, if any (e.g. lzf) eventLogCodec: Option[String] = None, coresPerExecutor: Option[Int] = None, + // number of executors this application wants to start with, + // only used if dynamic allocation is enabled + initialExecutorLimit: Option[Int] = None, user: String = System.getProperty("user.name", "")) { override def toString: String = "ApplicationDescription(" + name + ")" diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index 7e2cf956c725..4ffb5283e99a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -65,7 +65,7 @@ private[spark] class ApplicationInfo( appSource = new ApplicationSource(this) nextExecutorId = 0 removedExecutors = new ArrayBuffer[ExecutorDesc] - executorLimit = Integer.MAX_VALUE + executorLimit = desc.initialExecutorLimit.getOrElse(Integer.MAX_VALUE) appUIUrlAtHistoryServer = None } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 16f33163789a..d209645610c1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -19,11 +19,11 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.Semaphore -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv} +import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{AppClient, AppClientListener} import 
org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} -import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress} +import org.apache.spark.rpc.RpcEndpointAddress import org.apache.spark.scheduler._ import org.apache.spark.util.Utils @@ -89,8 +89,16 @@ private[spark] class SparkDeploySchedulerBackend( args, sc.executorEnvs, classPathEntries ++ testingClassPath, libraryPathEntries, javaOpts) val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("") val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt) - val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, - command, appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor) + // If we're using dynamic allocation, set our initial executor limit to 0 for now. + // ExecutorAllocationManager will send the real initial limit to the Master later. + val initialExecutorLimit = + if (Utils.isDynamicAllocationEnabled(conf)) { + Some(0) + } else { + None + } + val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, + appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor, initialExecutorLimit) client = new AppClient(sc.env.rpcEnv, masters, appDesc, this, conf) client.start() launcherBackend.setState(SparkAppHandle.State.SUBMITTED) diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index fdada0777f9a..b7ff5c9e8c0d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -447,7 +447,23 @@ class StandaloneDynamicAllocationSuite apps = getApplications() // kill executor successfully assert(apps.head.executors.size === 1) + } + test("initial executor limit") { + val initialExecutorLimit = 1 + val myConf = appConf + .set("spark.dynamicAllocation.enabled", "true") + .set("spark.shuffle.service.enabled", "true") + .set("spark.dynamicAllocation.initialExecutors", initialExecutorLimit.toString) + sc = new SparkContext(myConf) + val appId = sc.applicationId + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(apps.size === 1) + assert(apps.head.id === appId) + assert(apps.head.executors.size === initialExecutorLimit) + assert(apps.head.getExecutorLimit === initialExecutorLimit) + } } // =============================== @@ -540,7 +556,6 @@ class StandaloneDynamicAllocationSuite val missingExecutors = masterExecutors.toSet.diff(driverExecutors.toSet).toSeq.sorted missingExecutors.foreach { id => // Fake an executor registration so the driver knows about us - val port = System.currentTimeMillis % 65536 val endpointRef = mock(classOf[RpcEndpointRef]) val mockAddress = mock(classOf[RpcAddress]) when(endpointRef.address).thenReturn(mockAddress) From 15205da817b24ef0e349ec24d84034dc30b501f8 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 4 Feb 2016 10:34:43 -0800 Subject: [PATCH 128/131] [SPARK-13053][TEST] Unignore tests in InternalAccumulatorSuite These were ignored because they are incorrectly written; they don't actually trigger stage retries, which is what the tests are testing. These tests are now rewritten to induce stage retries through fetch failures. Note: there were 2 tests before and now there's only 1. What happened? 
It turns out that the case where we only resubmit a subset of the original missing partitions is very difficult to simulate in tests without potentially introducing flakiness. This is because the `DAGScheduler` removes all map outputs associated with a given executor when this happens, and we will need multiple executors to trigger this case, and sometimes the scheduler still removes map outputs from all executors. Author: Andrew Or Closes #10969 from andrewor14/unignore-accum-test. --- .../org/apache/spark/AccumulatorSuite.scala | 52 +++++-- .../spark/InternalAccumulatorSuite.scala | 128 +++++++++--------- 2 files changed, 102 insertions(+), 78 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index b8f2b96d7088..e0fdd4597385 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -323,35 +323,60 @@ private[spark] object AccumulatorSuite { * A simple listener that keeps track of the TaskInfos and StageInfos of all completed jobs. */ private class SaveInfoListener extends SparkListener { - private val completedStageInfos: ArrayBuffer[StageInfo] = new ArrayBuffer[StageInfo] - private val completedTaskInfos: ArrayBuffer[TaskInfo] = new ArrayBuffer[TaskInfo] - private var jobCompletionCallback: (Int => Unit) = null // parameter is job ID + type StageId = Int + type StageAttemptId = Int - // Accesses must be synchronized to ensure failures in `jobCompletionCallback` are propagated + private val completedStageInfos = new ArrayBuffer[StageInfo] + private val completedTaskInfos = + new mutable.HashMap[(StageId, StageAttemptId), ArrayBuffer[TaskInfo]] + + // Callback to call when a job completes. Parameter is job ID. @GuardedBy("this") + private var jobCompletionCallback: () => Unit = null + private var calledJobCompletionCallback: Boolean = false private var exception: Throwable = null def getCompletedStageInfos: Seq[StageInfo] = completedStageInfos.toArray.toSeq - def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.toArray.toSeq + def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.values.flatten.toSeq + def getCompletedTaskInfos(stageId: StageId, stageAttemptId: StageAttemptId): Seq[TaskInfo] = + completedTaskInfos.get((stageId, stageAttemptId)).getOrElse(Seq.empty[TaskInfo]) - /** Register a callback to be called on job end. */ - def registerJobCompletionCallback(callback: (Int => Unit)): Unit = { - jobCompletionCallback = callback + /** + * If `jobCompletionCallback` is set, block until the next call has finished. + * If the callback failed with an exception, throw it. + */ + def awaitNextJobCompletion(): Unit = synchronized { + if (jobCompletionCallback != null) { + while (!calledJobCompletionCallback) { + wait() + } + calledJobCompletionCallback = false + if (exception != null) { + exception = null + throw exception + } + } } - /** Throw a stored exception, if any. */ - def maybeThrowException(): Unit = synchronized { - if (exception != null) { throw exception } + /** + * Register a callback to be called on job end. + * A call to this should be followed by [[awaitNextJobCompletion]].
+ */ + def registerJobCompletionCallback(callback: () => Unit): Unit = { + jobCompletionCallback = callback } override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized { if (jobCompletionCallback != null) { try { - jobCompletionCallback(jobEnd.jobId) + jobCompletionCallback() } catch { // Store any exception thrown here so we can throw them later in the main thread. // Otherwise, if `jobCompletionCallback` threw something it wouldn't fail the test. case NonFatal(e) => exception = e + } finally { + calledJobCompletionCallback = true + notify() } } } @@ -361,7 +386,8 @@ private class SaveInfoListener extends SparkListener { } override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { - completedTaskInfos += taskEnd.taskInfo + completedTaskInfos.getOrElseUpdate( + (taskEnd.stageId, taskEnd.stageAttemptId), new ArrayBuffer[TaskInfo]) += taskEnd.taskInfo } } diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala index 630b46f828df..44a16e26f493 100644 --- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark import scala.collection.mutable.ArrayBuffer import org.apache.spark.scheduler.AccumulableInfo +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockId, BlockStatus} @@ -160,7 +161,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { iter } // Register asserts in job completion callback to avoid flakiness - listener.registerJobCompletionCallback { _ => + listener.registerJobCompletionCallback { () => val stageInfos = listener.getCompletedStageInfos val taskInfos = listener.getCompletedTaskInfos assert(stageInfos.size === 1) @@ -179,6 +180,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) } rdd.count() + listener.awaitNextJobCompletion() } test("internal accumulators in multiple stages") { @@ -205,7 +207,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { iter } // Register asserts in job completion callback to avoid flakiness - listener.registerJobCompletionCallback { _ => + listener.registerJobCompletionCallback { () => // We ran 3 stages, and the accumulator values should be distinct val stageInfos = listener.getCompletedStageInfos assert(stageInfos.size === 3) @@ -220,13 +222,66 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { rdd.count() } - // TODO: these two tests are incorrect; they don't actually trigger stage retries. - ignore("internal accumulators in fully resubmitted stages") { - testInternalAccumulatorsWithFailedTasks((i: Int) => true) // fail all tasks - } + test("internal accumulators in resubmitted stages") { + val listener = new SaveInfoListener + val numPartitions = 10 + sc = new SparkContext("local", "test") + sc.addSparkListener(listener) + + // Simulate fetch failures in order to trigger a stage retry. Here we run 1 job with + // 2 stages. On the second stage, we trigger a fetch failure on the first stage attempt. + // This should retry both stages in the scheduler. Note that we only want to fail the + // first stage attempt because we want the stage to eventually succeed. 
+ val x = sc.parallelize(1 to 100, numPartitions) + .mapPartitions { iter => TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 1; iter } + .groupBy(identity) + val sid = x.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle.shuffleId + val rdd = x.mapPartitionsWithIndex { case (i, iter) => + // Fail the first stage attempt. Here we use the task attempt ID to determine this. + // This job runs 2 stages, and we're in the second stage. Therefore, any task attempt + // ID that's < 2 * numPartitions belongs to the first attempt of this stage. + val taskContext = TaskContext.get() + val isFirstStageAttempt = taskContext.taskAttemptId() < numPartitions * 2 + if (isFirstStageAttempt) { + throw new FetchFailedException( + SparkEnv.get.blockManager.blockManagerId, + sid, + taskContext.partitionId(), + taskContext.partitionId(), + "simulated fetch failure") + } else { + iter + } + } - ignore("internal accumulators in partially resubmitted stages") { - testInternalAccumulatorsWithFailedTasks((i: Int) => i % 2 == 0) // fail a subset + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { () => + val stageInfos = listener.getCompletedStageInfos + assert(stageInfos.size === 4) // 1 shuffle map stage + 1 result stage, both are retried + val mapStageId = stageInfos.head.stageId + val mapStageInfo1stAttempt = stageInfos.head + val mapStageInfo2ndAttempt = { + stageInfos.tail.find(_.stageId == mapStageId).getOrElse { + fail("expected two attempts of the same shuffle map stage.") + } + } + val stageAccum1stAttempt = findTestAccum(mapStageInfo1stAttempt.accumulables.values) + val stageAccum2ndAttempt = findTestAccum(mapStageInfo2ndAttempt.accumulables.values) + // Both map stages should have succeeded, since the fetch failure happened in the + // result stage, not the map stage. This means we should get the accumulator updates + // from all partitions. + assert(stageAccum1stAttempt.value.get.asInstanceOf[Long] === numPartitions) + assert(stageAccum2ndAttempt.value.get.asInstanceOf[Long] === numPartitions) + // Because this test resubmitted the map stage with all missing partitions, we should have + // created a fresh set of internal accumulators in the 2nd stage attempt. Assert this is + // the case by comparing the accumulator IDs between the two attempts. + // Note: it would be good to also test the case where the map stage is resubmitted where + // only a subset of the original partitions are missing. However, this scenario is very + // difficult to construct without potentially introducing flakiness. + assert(stageAccum1stAttempt.id != stageAccum2ndAttempt.id) + } + rdd.count() + listener.awaitNextJobCompletion() } test("internal accumulators are registered for cleanups") { @@ -257,63 +312,6 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { } } - /** - * Test whether internal accumulators are merged properly if some tasks fail. - * TODO: make this actually retry the stage. 
- */ - private def testInternalAccumulatorsWithFailedTasks(failCondition: (Int => Boolean)): Unit = { - val listener = new SaveInfoListener - val numPartitions = 10 - val numFailedPartitions = (0 until numPartitions).count(failCondition) - // This says use 1 core and retry tasks up to 2 times - sc = new SparkContext("local[1, 2]", "test") - sc.addSparkListener(listener) - val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) => - val taskContext = TaskContext.get() - taskContext.taskMetrics().getAccum(TEST_ACCUM) += 1 - // Fail the first attempts of a subset of the tasks - if (failCondition(i) && taskContext.attemptNumber() == 0) { - throw new Exception("Failing a task intentionally.") - } - iter - } - // Register asserts in job completion callback to avoid flakiness - listener.registerJobCompletionCallback { _ => - val stageInfos = listener.getCompletedStageInfos - val taskInfos = listener.getCompletedTaskInfos - assert(stageInfos.size === 1) - assert(taskInfos.size === numPartitions + numFailedPartitions) - val stageAccum = findTestAccum(stageInfos.head.accumulables.values) - // If all partitions failed, then we would resubmit the whole stage again and create a - // fresh set of internal accumulators. Otherwise, these internal accumulators do count - // failed values, so we must include the failed values. - val expectedAccumValue = - if (numPartitions == numFailedPartitions) { - numPartitions - } else { - numPartitions + numFailedPartitions - } - assert(stageAccum.value.get.asInstanceOf[Long] === expectedAccumValue) - val taskAccumValues = taskInfos.flatMap { taskInfo => - if (!taskInfo.failed) { - // If a task succeeded, its update value should always be 1 - val taskAccum = findTestAccum(taskInfo.accumulables) - assert(taskAccum.update.isDefined) - assert(taskAccum.update.get.asInstanceOf[Long] === 1L) - assert(taskAccum.value.isDefined) - Some(taskAccum.value.get.asInstanceOf[Long]) - } else { - // If a task failed, we should not get its accumulator values - assert(taskInfo.accumulables.isEmpty) - None - } - } - assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) - } - rdd.count() - listener.maybeThrowException() - } - /** * A special [[ContextCleaner]] that saves the IDs of the accumulators registered for cleanup. */ From 085f510ae554e2739a38ee0bc7210c4ece902f3f Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 4 Feb 2016 11:07:06 -0800 Subject: [PATCH 129/131] MAINTENANCE: Automated closing of pull requests. This commit exists to close the following pull requests on Github: Closes #7971 (requested by yhuai) Closes #8539 (requested by srowen) Closes #8746 (requested by yhuai) Closes #9288 (requested by andrewor14) Closes #9321 (requested by andrewor14) Closes #9935 (requested by JoshRosen) Closes #10442 (requested by andrewor14) Closes #10585 (requested by srowen) Closes #10785 (requested by srowen) Closes #10832 (requested by andrewor14) Closes #10941 (requested by marmbrus) Closes #11024 (requested by andrewor14) From 33212cb9a13a6012b4c19ccfc0fb3db75de304da Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 4 Feb 2016 11:08:50 -0800 Subject: [PATCH 130/131] [SPARK-13168][SQL] Collapse adjacent repartition operators Spark SQL should collapse adjacent `Repartition` operators and only keep the last one. Author: Josh Rosen Closes #11064 from JoshRosen/collapse-repartition. 
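As a rough illustration of the intended behavior (this mirrors the "collapse adjacent repartitions" test added to PlannerSuite below; the DataFrame `df` is assumed to exist and is not part of this patch):

    // Each repartition/coalesce call adds a Repartition node to the logical plan;
    // after optimization only the last one should survive.
    import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}

    val chained = df.repartition(10).repartition(20).coalesce(5)
    def countRepartitions(plan: LogicalPlan): Int =
      plan.collect { case r: Repartition => r }.length

    assert(countRepartitions(chained.queryExecution.logical) === 3)
    assert(countRepartitions(chained.queryExecution.optimizedPlan) === 1)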
--- .../spark/sql/catalyst/optimizer/Optimizer.scala | 16 ++++++++++++++-- ...ingSuite.scala => CollapseProjectSuite.scala} | 4 ++-- .../catalyst/optimizer/FilterPushdownSuite.scala | 2 +- .../sql/catalyst/optimizer/JoinOrderSuite.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 15 +++++++++++++-- .../org/apache/spark/sql/hive/SQLBuilder.scala | 4 ++-- 6 files changed, 33 insertions(+), 10 deletions(-) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/{ProjectCollapsingSuite.scala => CollapseProjectSuite.scala} (96%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 4ecee7504824..a1ac93073916 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -68,7 +68,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { PushPredicateThroughAggregate, ColumnPruning, // Operator combine - ProjectCollapsing, + CollapseRepartition, + CollapseProject, CombineFilters, CombineLimits, CombineUnions, @@ -322,7 +323,7 @@ object ColumnPruning extends Rule[LogicalPlan] { * Combines two adjacent [[Project]] operators into one and perform alias substitution, * merging the expressions into one single expression. */ -object ProjectCollapsing extends Rule[LogicalPlan] { +object CollapseProject extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case p @ Project(projectList1, Project(projectList2, child)) => @@ -390,6 +391,16 @@ object ProjectCollapsing extends Rule[LogicalPlan] { } } +/** + * Combines adjacent [[Repartition]] operators by keeping only the last one. + */ +object CollapseRepartition extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case r @ Repartition(numPartitions, shuffle, Repartition(_, _, child)) => + Repartition(numPartitions, shuffle, child) + } +} + /** * Simplifies LIKE expressions that do not need full regular expressions to evaluate the condition. * For example, when the expression is just checking to see if a string starts with a given @@ -857,6 +868,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { /** * Splits join condition expressions into three categories based on the attributes required * to evaluate them. 
+ * * @return (canEvaluateInLeft, canEvaluateInRight, haveToEvaluateInBoth) */ private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala similarity index 96% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala index 85b6530481b0..f5fd5ca6beb1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala @@ -25,11 +25,11 @@ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor -class ProjectCollapsingSuite extends PlanTest { +class CollapseProjectSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", FixedPoint(10), EliminateSubQueries) :: - Batch("ProjectCollapsing", Once, ProjectCollapsing) :: Nil + Batch("CollapseProject", Once, CollapseProject) :: Nil } val testRelation = LocalRelation('a.int, 'b.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index f9f3bd55aa57..b49ca928b629 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -42,7 +42,7 @@ class FilterPushdownSuite extends PlanTest { PushPredicateThroughGenerate, PushPredicateThroughAggregate, ColumnPruning, - ProjectCollapsing) :: Nil + CollapseProject) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala index 9b1e16c72764..858a0d8fde3e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala @@ -43,7 +43,7 @@ class JoinOrderSuite extends PlanTest { PushPredicateThroughGenerate, PushPredicateThroughAggregate, ColumnPruning, - ProjectCollapsing) :: Nil + CollapseProject) :: Nil } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 8fca5e2167d0..adaeb513bc1b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -21,8 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{execution, Row, SQLConf} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder} -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, 
Repartition} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin} import org.apache.spark.sql.functions._ @@ -223,6 +222,18 @@ class PlannerSuite extends SharedSQLContext { } } + test("collapse adjacent repartitions") { + val doubleRepartitioned = testData.repartition(10).repartition(20).coalesce(5) + def countRepartitions(plan: LogicalPlan): Int = plan.collect { case r: Repartition => r }.length + assert(countRepartitions(doubleRepartitioned.queryExecution.logical) === 3) + assert(countRepartitions(doubleRepartitioned.queryExecution.optimizedPlan) === 1) + doubleRepartitioned.queryExecution.optimizedPlan match { + case r: Repartition => + assert(r.numPartitions === 5) + assert(r.shuffle === false) + } + } + // --- Unit tests of EnsureRequirements --------------------------------------------------------- // When it comes to testing whether EnsureRequirements properly ensures distribution requirements, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index 165459453836..fc5725d6915e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -23,7 +23,7 @@ import org.apache.spark.Logging import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression, SortOrder} -import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing +import org.apache.spark.sql.catalyst.optimizer.CollapseProject import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} import org.apache.spark.sql.execution.datasources.LogicalRelation @@ -188,7 +188,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi // The `WidenSetOperationTypes` analysis rule may introduce extra `Project`s over // `Aggregate`s to perform type casting. This rule merges these `Project`s into // `Aggregate`s. - ProjectCollapsing, + CollapseProject, // Used to handle other auxiliary `Project`s added by analyzer (e.g. // `ResolveAggregateFunctions` rule) From ecad77a6ac85892f1155f596e84729342e484088 Mon Sep 17 00:00:00 2001 From: Michael Gummelt Date: Tue, 19 Jan 2016 14:24:58 -0800 Subject: [PATCH 131/131] Support multiple executors per node on Mesos. Support spark.executor.cores on Mesos. 
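As a sketch only (the master URL, core counts, and memory size below are placeholders, not taken from this patch): once spark.executor.cores is honored, the coarse-grained backend can carve several fixed-size executors out of one large Mesos offer instead of launching a single executor that consumes every offered core.

    // Hypothetical sizing: spark.cores.max = 16 with spark.executor.cores = 4
    // allows up to four 4-core executors, which may now share a Mesos node.
    import org.apache.spark.{SparkConf, SparkContext}

    val conf = new SparkConf()
      .setMaster("mesos://zk://zk1:2181/mesos")  // placeholder Mesos master
      .setAppName("MultiExecutorPerNodeSketch")
      .set("spark.executor.cores", "4")
      .set("spark.cores.max", "16")
      .set("spark.executor.memory", "2g")
    val sc = new SparkContext(conf)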
--- .../CoarseGrainedSchedulerBackend.scala | 11 +- .../mesos/CoarseMesosSchedulerBackend.scala | 375 +++++++++++------- .../cluster/mesos/MesosSchedulerBackend.scala | 4 +- .../cluster/mesos/MesosSchedulerUtils.scala | 10 +- .../CoarseMesosSchedulerBackendSuite.scala | 365 ++++++++++++----- .../mesos/MesosSchedulerBackendSuite.scala | 2 +- .../mesos/MesosSchedulerUtilsSuite.scala | 6 +- docs/configuration.md | 15 +- docs/running-on-mesos.md | 8 +- 9 files changed, 521 insertions(+), 275 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index f69a3d371e5d..0a5b09dc0d1f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -240,6 +240,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp else { val executorData = executorDataMap(task.executorId) executorData.freeCores -= scheduler.CPUS_PER_TASK + + logInfo(s"Launching task ${task.taskId} on executor id: ${task.executorId} hostname: " + + s"${executorData.executorHost}.") + executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask))) } } @@ -309,7 +313,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // TODO (prashant) send conf instead of properties - driverEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME, createDriverEndpoint(properties)) + driverEndpoint = createDriverEndpointRef(properties) + } + + protected def createDriverEndpointRef( + properties: ArrayBuffer[(String, String)]): RpcEndpointRef = { + rpcEnv.setupEndpoint(ENDPOINT_NAME, createDriverEndpoint(properties)) } protected def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 722293bb7a53..999d38b29bfe 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -23,17 +23,17 @@ import java.util.concurrent.TimeUnit import java.util.concurrent.locks.ReentrantLock import scala.collection.JavaConverters._ -import scala.collection.mutable.{HashMap, HashSet} +import scala.collection.mutable +import scala.collection.mutable.{Buffer, HashMap, HashSet} import com.google.common.base.Stopwatch -import com.google.common.collect.HashBiMap import org.apache.mesos.{Scheduler => MScheduler, SchedulerDriver} import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} -import org.apache.spark.{SecurityManager, SparkContext, SparkEnv, SparkException, TaskState} +import org.apache.spark.{SecurityManager, SparkContext, SparkException, TaskState} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient -import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress} +import org.apache.spark.rpc.{RpcEndpointAddress} import org.apache.spark.scheduler.{SlaveLost, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils @@ -73,17 +73,13 @@ private[spark] class 
CoarseMesosSchedulerBackend( private val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) // Cores we have acquired with each Mesos task ID - val coresByTaskId = new HashMap[Int, Int] + val coresByTaskId = new HashMap[String, Int] var totalCoresAcquired = 0 - val slaveIdsWithExecutors = new HashSet[String] - - // Maping from slave Id to hostname - private val slaveIdToHost = new HashMap[String, String] - - val taskIdToSlaveId: HashBiMap[Int, String] = HashBiMap.create[Int, String] - // How many times tasks on each slave failed - val failuresBySlaveId: HashMap[String, Int] = new HashMap[String, Int] + // SlaveID -> Slave + // This map accumulates entries for the duration of the job. Slaves are never deleted, because + // we need to maintain e.g. failure state and connection state. + private val slaves = new HashMap[String, Slave] /** * The total number of executors we aim to have. Undefined when not using dynamic allocation @@ -97,13 +93,11 @@ private[spark] class CoarseMesosSchedulerBackend( */ private[mesos] def executorLimit: Int = executorLimitOption.getOrElse(Int.MaxValue) - private val pendingRemovedSlaveIds = new HashSet[String] - // private lock object protecting mutable state above. Using the intrinsic lock // may lead to deadlocks since the superclass might also try to lock private val stateLock = new ReentrantLock - val extraCoresPerSlave = conf.getInt("spark.mesos.extra.cores", 0) + val extraCoresPerExecutor = conf.getInt("spark.mesos.extra.cores", 0) // Offer constraints private val slaveOfferConstraints = @@ -113,27 +107,31 @@ private[spark] class CoarseMesosSchedulerBackend( private val rejectOfferDurationForUnmetConstraints = getRejectOfferDurationForUnmetConstraints(sc) - // A client for talking to the external shuffle service, if it is a + // A client for talking to the external shuffle service private val mesosExternalShuffleClient: Option[MesosExternalShuffleClient] = { if (shuffleServiceEnabled) { - Some(new MesosExternalShuffleClient( - SparkTransportConf.fromSparkConf(conf, "shuffle"), - securityManager, - securityManager.isAuthenticationEnabled(), - securityManager.isSaslEncryptionEnabled())) + Some(getShuffleClient()) } else { None } } + protected def getShuffleClient(): MesosExternalShuffleClient = { + new MesosExternalShuffleClient( + SparkTransportConf.fromSparkConf(conf, "shuffle"), + securityManager, + securityManager.isAuthenticationEnabled(), + securityManager.isSaslEncryptionEnabled()) + } + var nextMesosTaskId = 0 @volatile var appId: String = _ - def newMesosTaskId(): Int = { + def newMesosTaskId(): String = { val id = nextMesosTaskId nextMesosTaskId += 1 - id + id.toString } override def start() { @@ -148,7 +146,7 @@ private[spark] class CoarseMesosSchedulerBackend( startScheduler(driver) } - def createCommand(offer: Offer, numCores: Int, taskId: Int): CommandInfo = { + def createCommand(offer: Offer, numCores: Int, taskId: String): CommandInfo = { val executorSparkHome = conf.getOption("spark.mesos.executor.home") .orElse(sc.getSparkHome()) .getOrElse { @@ -192,7 +190,7 @@ private[spark] class CoarseMesosSchedulerBackend( "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend" .format(prefixEnv, runScript) + s" --driver-url $driverURL" + - s" --executor-id ${offer.getSlaveId.getValue}" + + s" --executor-id $taskId" + s" --hostname ${offer.getHostname}" + s" --cores $numCores" + s" --app-id $appId") @@ -200,12 +198,11 @@ private[spark] class CoarseMesosSchedulerBackend( // Grab everything to the first '.'. 
We'll use that and '*' to // glob the directory "correctly". val basename = uri.get.split('/').last.split('.').head - val executorId = sparkExecutorId(offer.getSlaveId.getValue, taskId.toString) command.setValue( s"cd $basename*; $prefixEnv " + - "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" + + "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" + s" --driver-url $driverURL" + - s" --executor-id $executorId" + + s" --executor-id $taskId" + s" --hostname ${offer.getHostname}" + s" --cores $numCores" + s" --app-id $appId") @@ -260,113 +257,209 @@ private[spark] class CoarseMesosSchedulerBackend( offers.asScala.map(_.getId).foreach(d.declineOffer) return } - val filters = Filters.newBuilder().setRefuseSeconds(5).build() - for (offer <- offers.asScala) { + + logDebug(s"Received ${offers.size} resource offers.") + + val (matchedOffers, unmatchedOffers) = offers.asScala.partition { offer => val offerAttributes = toAttributeMap(offer.getAttributesList) - val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + } + + declineUnmatchedOffers(d, unmatchedOffers) + handleMatchedOffers(d, matchedOffers) + } + } + + private def declineUnmatchedOffers(d: SchedulerDriver, offers: Buffer[Offer]): Unit = { + for (offer <- offers) { + val id = offer.getId.getValue + val offerAttributes = toAttributeMap(offer.getAttributesList) + val mem = getResource(offer.getResourcesList, "mem") + val cpus = getResource(offer.getResourcesList, "cpus") + val filters = Filters.newBuilder() + .setRefuseSeconds(rejectOfferDurationForUnmetConstraints).build() + + logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus" + + s" for $rejectOfferDurationForUnmetConstraints seconds") + + d.declineOffer(offer.getId, filters) + } + } + + /** + * Launches executors on accepted offers, and declines unused offers. Executors are launched + * round-robin on offers. + * + * @param d SchedulerDriver + * @param offers Mesos offers that match attribute constraints + */ + private def handleMatchedOffers(d: SchedulerDriver, offers: Buffer[Offer]): Unit = { + val tasks = buildMesosTasks(offers) + for (offer <- offers) { + val offerAttributes = toAttributeMap(offer.getAttributesList) + val offerMem = getResource(offer.getResourcesList, "mem") + val offerCpus = getResource(offer.getResourcesList, "cpus") + val id = offer.getId.getValue + + if (tasks.contains(offer.getId)) { // accept + val offerTasks = tasks(offer.getId) + + logDebug(s"Accepting offer: $id with attributes: $offerAttributes " + + s"mem: $offerMem cpu: $offerCpus. Launching ${offerTasks.size} Mesos tasks.") + + for (task <- offerTasks) { + val taskId = task.getTaskId + val mem = getResource(task.getResourcesList, "mem") + val cpus = getResource(task.getResourcesList, "cpus") + + logDebug(s"Launching Mesos task: ${taskId.getValue} with mem: $mem cpu: $cpus.") + } + + d.launchTasks( + Collections.singleton(offer.getId), + offerTasks.asJava) + } else { // decline + logDebug(s"Declining offer: $id with attributes: $offerAttributes " + + s"mem: $offerMem cpu: $offerCpus") + + d.declineOffer(offer.getId) + } + } + } + + /** + * Returns a map from OfferIDs to the tasks to launch on those offers. In order to maximize + * per-task memory and IO, tasks are round-robin assigned to offers. 
+ * + * @param offers Mesos offers that match attribute constraints + * @return A map from OfferID to a list of Mesos tasks to launch on that offer + */ + private def buildMesosTasks(offers: Buffer[Offer]): Map[OfferID, List[MesosTaskInfo]] = { + // offerID -> tasks + val tasks = new HashMap[OfferID, List[MesosTaskInfo]].withDefaultValue(Nil) + + // offerID -> resources + val remainingResources = mutable.Map(offers.map(offer => + (offer.getId.getValue, offer.getResourcesList)): _*) + + var launchTasks = true + + // TODO(mgummelt): combine offers for a single slave + // + // round-robin create executors on the available offers + while (launchTasks) { + launchTasks = false + + for (offer <- offers) { val slaveId = offer.getSlaveId.getValue - val mem = getResource(offer.getResourcesList, "mem") - val cpus = getResource(offer.getResourcesList, "cpus").toInt - val id = offer.getId.getValue - if (meetsConstraints) { - if (taskIdToSlaveId.size < executorLimit && - totalCoresAcquired < maxCores && - mem >= calculateTotalMemory(sc) && - cpus >= 1 && - failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && - !slaveIdsWithExecutors.contains(slaveId)) { - // Launch an executor on the slave - val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired) - totalCoresAcquired += cpusToUse - val taskId = newMesosTaskId() - taskIdToSlaveId.put(taskId, slaveId) - slaveIdsWithExecutors += slaveId - coresByTaskId(taskId) = cpusToUse - // Gather cpu resources from the available resources and use them in the task. - val (remainingResources, cpuResourcesToUse) = - partitionResources(offer.getResourcesList, "cpus", cpusToUse) - val (_, memResourcesToUse) = - partitionResources(remainingResources.asJava, "mem", calculateTotalMemory(sc)) - val taskBuilder = MesosTaskInfo.newBuilder() - .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) - .setSlaveId(offer.getSlaveId) - .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave, taskId)) - .setName("Task " + taskId) - .addAllResources(cpuResourcesToUse.asJava) - .addAllResources(memResourcesToUse.asJava) - - sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => - MesosSchedulerBackendUtil - .setupContainerBuilderDockerInfo(image, sc.conf, taskBuilder.getContainerBuilder()) - } - - // Accept the offer and launch the task - logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") - slaveIdToHost(offer.getSlaveId.getValue) = offer.getHostname - d.launchTasks( - Collections.singleton(offer.getId), - Collections.singleton(taskBuilder.build()), filters) - } else { - // Decline the offer - logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") - d.declineOffer(offer.getId) + val offerId = offer.getId.getValue + val resources = remainingResources(offerId) + + if (canLaunchTask(slaveId, resources)) { + // Create a task + launchTasks = true + val taskId = newMesosTaskId() + val offerCPUs = getResource(resources, "cpus").toInt + + val taskCPUs = executorCores(offerCPUs) + val taskMemory = executorMemory(sc) + + slaves.getOrElseUpdate(slaveId, new Slave(offer.getHostname)).taskIDs.add(taskId) + + val (afterCPUResources, cpuResourcesToUse) = + partitionResources(resources, "cpus", taskCPUs) + val (resourcesLeft, memResourcesToUse) = + partitionResources(afterCPUResources.asJava, "mem", taskMemory) + + val taskBuilder = MesosTaskInfo.newBuilder() + .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) + .setSlaveId(offer.getSlaveId) + 
.setCommand(createCommand(offer, taskCPUs + extraCoresPerExecutor, taskId)) + .setName("Task " + taskId) + .addAllResources(cpuResourcesToUse.asJava) + .addAllResources(memResourcesToUse.asJava) + + sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => + MesosSchedulerBackendUtil + .setupContainerBuilderDockerInfo(image, sc.conf, taskBuilder.getContainerBuilder) } - } else { - // This offer does not meet constraints. We don't need to see it again. - // Decline the offer for a long period of time. - logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus" - + s" for $rejectOfferDurationForUnmetConstraints seconds") - d.declineOffer(offer.getId, Filters.newBuilder() - .setRefuseSeconds(rejectOfferDurationForUnmetConstraints).build()) + + tasks(offer.getId) ::= taskBuilder.build() + remainingResources(offerId) = resourcesLeft.asJava + totalCoresAcquired += taskCPUs + coresByTaskId(taskId) = taskCPUs } } } + tasks.toMap + } + + private def canLaunchTask(slaveId: String, resources: JList[Resource]): Boolean = { + val offerMem = getResource(resources, "mem") + val offerCPUs = getResource(resources, "cpus").toInt + val cpus = executorCores(offerCPUs) + val mem = executorMemory(sc) + + cpus > 0 && + cpus <= offerCPUs && + cpus + totalCoresAcquired <= maxCores && + mem <= offerMem && + numExecutors() < executorLimit && + slaves.get(slaveId).map(_.taskFailures).getOrElse(0) < MAX_SLAVE_FAILURES } + private def executorCores(offerCPUs: Int): Int = { + sc.conf.getInt("spark.executor.cores", + math.min(offerCPUs, maxCores - totalCoresAcquired)) + } override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { - val taskId = status.getTaskId.getValue.toInt - val state = status.getState - logInfo(s"Mesos task $taskId is now $state") - val slaveId: String = status.getSlaveId.getValue + val taskId = status.getTaskId.getValue + val slaveId = status.getSlaveId.getValue + val state = TaskState.fromMesos(status.getState) + + logInfo(s"Mesos task $taskId is now ${status.getState}") + stateLock.synchronized { + val slave = slaves(slaveId) + // If the shuffle service is enabled, have the driver register with each one of the // shuffle services. This allows the shuffle services to clean up state associated with // this application when the driver exits. There is currently not a great way to detect // this through Mesos, since the shuffle services are set up independently. - if (TaskState.fromMesos(state).equals(TaskState.RUNNING) && - slaveIdToHost.contains(slaveId) && - shuffleServiceEnabled) { + if (state.equals(TaskState.RUNNING) && + shuffleServiceEnabled && + !slave.shuffleRegistered) { assume(mesosExternalShuffleClient.isDefined, "External shuffle client was not instantiated even though shuffle service is enabled.") // TODO: Remove this and allow the MesosExternalShuffleService to detect // framework termination when new Mesos Framework HTTP API is available. 
val externalShufflePort = conf.getInt("spark.shuffle.service.port", 7337) - val hostname = slaveIdToHost.remove(slaveId).get + logDebug(s"Connecting to shuffle service on slave $slaveId, " + - s"host $hostname, port $externalShufflePort for app ${conf.getAppId}") + s"host ${slave.hostname}, port $externalShufflePort for app ${conf.getAppId}") + mesosExternalShuffleClient.get - .registerDriverWithShuffleService(hostname, externalShufflePort) + .registerDriverWithShuffleService(slave.hostname, externalShufflePort) + slave.shuffleRegistered = true } - if (TaskState.isFinished(TaskState.fromMesos(state))) { - val slaveId = taskIdToSlaveId.get(taskId) - slaveIdsWithExecutors -= slaveId - taskIdToSlaveId.remove(taskId) + if (TaskState.isFinished(state)) { // Remove the cores we have remembered for this task, if it's in the hashmap for (cores <- coresByTaskId.get(taskId)) { totalCoresAcquired -= cores coresByTaskId -= taskId } // If it was a failure, mark the slave as failed for blacklisting purposes - if (TaskState.isFailed(TaskState.fromMesos(state))) { - failuresBySlaveId(slaveId) = failuresBySlaveId.getOrElse(slaveId, 0) + 1 - if (failuresBySlaveId(slaveId) >= MAX_SLAVE_FAILURES) { + if (TaskState.isFailed(state)) { + slave.taskFailures += 1 + + if (slave.taskFailures >= MAX_SLAVE_FAILURES) { logInfo(s"Blacklisting Mesos slave $slaveId due to too many failures; " + "is Spark installed on it?") } } - executorTerminated(d, slaveId, s"Executor finished with state $state") + executorTerminated(d, slaveId, taskId, s"Executor finished with state $state") // In case we'd rejected everything before but have now lost a node d.reviveOffers() } @@ -388,20 +481,24 @@ private[spark] class CoarseMesosSchedulerBackend( stopCalled = true super.stop() } + // Wait for executors to report done, or else mesosDriver.stop() will forcefully kill them. // See SPARK-12330 val stopwatch = new Stopwatch() stopwatch.start() + // slaveIdsWithExecutors has no memory barrier, so this is eventually consistent - while (slaveIdsWithExecutors.nonEmpty && + while (numExecutors() > 0 && stopwatch.elapsed(TimeUnit.MILLISECONDS) < shutdownTimeoutMS) { Thread.sleep(100) } - if (slaveIdsWithExecutors.nonEmpty) { - logWarning(s"Timed out waiting for ${slaveIdsWithExecutors.size} remaining executors " + + if (numExecutors() > 0) { + logWarning(s"Timed out waiting for ${numExecutors()} remaining executors " + s"to terminate within $shutdownTimeoutMS ms. This may leave temporary files " + "on the mesos nodes.") } + if (mesosDriver != null) { mesosDriver.stop() } @@ -410,40 +507,25 @@ private[spark] class CoarseMesosSchedulerBackend( override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} /** - * Called when a slave is lost or a Mesos task finished. Update local view on - * what tasks are running and remove the terminated slave from the list of pending - * slave IDs that we might have asked to be killed. It also notifies the driver - * that an executor was removed. + * Called when a slave is lost or a Mesos task finished. Updates local view on + * what tasks are running. It also notifies the driver that an executor was removed. 
*/ - private def executorTerminated(d: SchedulerDriver, slaveId: String, reason: String): Unit = { + private def executorTerminated(d: SchedulerDriver, + slaveId: String, + taskId: String, + reason: String): Unit = { stateLock.synchronized { - if (slaveIdsWithExecutors.contains(slaveId)) { - val slaveIdToTaskId = taskIdToSlaveId.inverse() - if (slaveIdToTaskId.containsKey(slaveId)) { - val taskId: Int = slaveIdToTaskId.get(slaveId) - taskIdToSlaveId.remove(taskId) - removeExecutor(sparkExecutorId(slaveId, taskId.toString), SlaveLost(reason)) - } - // TODO: This assumes one Spark executor per Mesos slave, - // which may no longer be true after SPARK-5095 - pendingRemovedSlaveIds -= slaveId - slaveIdsWithExecutors -= slaveId - } + removeExecutor(taskId, SlaveLost(reason)) + slaves(slaveId).taskIDs.remove(taskId) } } - private def sparkExecutorId(slaveId: String, taskId: String): String = { - s"$slaveId/$taskId" - } - override def slaveLost(d: SchedulerDriver, slaveId: SlaveID): Unit = { logInfo(s"Mesos slave lost: ${slaveId.getValue}") - executorTerminated(d, slaveId.getValue, "Mesos slave lost: " + slaveId.getValue) } override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int): Unit = { - logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue)) - slaveLost(d, s) + logInfo("Mesos executor lost: %s".format(e.getValue)) } override def applicationId(): String = @@ -463,23 +545,26 @@ private[spark] class CoarseMesosSchedulerBackend( override def doKillExecutors(executorIds: Seq[String]): Boolean = { if (mesosDriver == null) { logWarning("Asked to kill executors before the Mesos driver was started.") - return false - } - - val slaveIdToTaskId = taskIdToSlaveId.inverse() - for (executorId <- executorIds) { - val slaveId = executorId.split("/")(0) - if (slaveIdToTaskId.containsKey(slaveId)) { - mesosDriver.killTask( - TaskID.newBuilder().setValue(slaveIdToTaskId.get(slaveId).toString).build()) - pendingRemovedSlaveIds += slaveId - } else { - logWarning("Unable to find executor Id '" + executorId + "' in Mesos scheduler") + false + } else { + for (executorId <- executorIds) { + val taskId = TaskID.newBuilder().setValue(executorId).build() + mesosDriver.killTask(taskId) } + // no need to adjust `executorLimitOption` since the AllocationManager already communicated + // the desired limit through a call to `doRequestTotalExecutors`. + // See [[o.a.s.scheduler.cluster.CoarseGrainedSchedulerBackend.killExecutors]] + true } - // no need to adjust `executorLimitOption` since the AllocationManager already communicated - // the desired limit through a call to `doRequestTotalExecutors`. 
- // See [[o.a.s.scheduler.cluster.CoarseGrainedSchedulerBackend.killExecutors]] - true } + + private def numExecutors(): Int = { + slaves.values.map(_.taskIDs.size).sum + } +} + +private class Slave(val hostname: String) { + val taskIDs = new HashSet[String]() + var taskFailures = 0 + var shuffleRegistered = false } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 340f29bac921..8929d8a42778 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -138,7 +138,7 @@ private[spark] class MesosSchedulerBackend( val (resourcesAfterCpu, usedCpuResources) = partitionResources(availableResources, "cpus", mesosExecutorCores) val (resourcesAfterMem, usedMemResources) = - partitionResources(resourcesAfterCpu.asJava, "mem", calculateTotalMemory(sc)) + partitionResources(resourcesAfterCpu.asJava, "mem", executorMemory(sc)) builder.addAllResources(usedCpuResources.asJava) builder.addAllResources(usedMemResources.asJava) @@ -250,7 +250,7 @@ private[spark] class MesosSchedulerBackend( // check offers for // 1. Memory requirements // 2. CPU requirements - need at least 1 for executor, 1 for task - val meetsMemoryRequirements = mem >= calculateTotalMemory(sc) + val meetsMemoryRequirements = mem >= executorMemory(sc) val meetsCPURequirements = cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK) val meetsRequirements = (meetsMemoryRequirements && meetsCPURequirements) || diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index f9f5da9bc8df..a98f2f1fe5da 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -140,15 +140,15 @@ private[mesos] trait MesosSchedulerUtils extends Logging { } } - /** - * Signal that the scheduler has registered with Mesos. - */ - protected def getResource(res: JList[Resource], name: String): Double = { + def getResource(res: JList[Resource], name: String): Double = { // A resource can have multiple values in the offer since it can either be from // a specific role or wildcard. res.asScala.filter(_.getName == name).map(_.getScalar.getValue).sum } + /** + * Signal that the scheduler has registered with Mesos. 
+ */ protected def markRegistered(): Unit = { registerLatch.countDown() } @@ -337,7 +337,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { * @return memory requirement as (0.1 * ) or MEMORY_OVERHEAD_MINIMUM * (whichever is larger) */ - def calculateTotalMemory(sc: SparkContext): Int = { + def executorMemory(sc: SparkContext): Int = { sc.conf.getInt("spark.mesos.executor.memoryOverhead", math.max(MEMORY_OVERHEAD_FRACTION * sc.executorMemory, MEMORY_OVERHEAD_MINIMUM).toInt) + sc.executorMemory diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala index a4110d2d462d..e542aa0cfc4d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala @@ -17,19 +17,23 @@ package org.apache.spark.scheduler.cluster.mesos -import java.util import java.util.Collections +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + import org.apache.mesos.{Protos, Scheduler, SchedulerDriver} import org.apache.mesos.Protos._ import org.apache.mesos.Protos.Value.Scalar -import org.mockito.Matchers +import org.mockito.{ArgumentCaptor, Matchers} import org.mockito.Matchers._ import org.mockito.Mockito._ import org.scalatest.mock.MockitoSugar import org.scalatest.BeforeAndAfter import org.apache.spark.{LocalSparkContext, SecurityManager, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient +import org.apache.spark.rpc.{RpcEndpointRef} import org.apache.spark.scheduler.TaskSchedulerImpl class CoarseMesosSchedulerBackendSuite extends SparkFunSuite @@ -37,6 +41,223 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite with MockitoSugar with BeforeAndAfter { + var sparkConf: SparkConf = _ + var driver: SchedulerDriver = _ + var taskScheduler: TaskSchedulerImpl = _ + var backend: CoarseMesosSchedulerBackend = _ + var externalShuffleClient: MesosExternalShuffleClient = _ + var driverEndpoint: RpcEndpointRef = _ + + test("mesos supports killing and limiting executors") { + setBackend() + sparkConf.set("spark.driver.host", "driverHost") + sparkConf.set("spark.driver.port", "1234") + + val minMem = backend.executorMemory(sc) + val minCpu = 4 + val offers = List((minMem, minCpu)) + + // launches a task on a valid offer + offerResources(offers) + verifyTaskLaunched("o1") + + // kills executors + backend.doRequestTotalExecutors(0) + assert(backend.doKillExecutors(Seq("0"))) + val taskID0 = createTaskId("0") + verify(driver, times(1)).killTask(taskID0) + + // doesn't launch a new task when requested executors == 0 + offerResources(offers, 2) + verifyDeclinedOffer(driver, createOfferId("o2")) + + // Launches a new task when requested executors is positive + backend.doRequestTotalExecutors(2) + offerResources(offers, 2) + verifyTaskLaunched("o2") + } + + test("mesos supports killing and relaunching tasks with executors") { + setBackend() + + // launches a task on a valid offer + val minMem = backend.executorMemory(sc) + 1024 + val minCpu = 4 + val offer1 = (minMem, minCpu) + val offer2 = (minMem, 1) + offerResources(List(offer1, offer2)) + verifyTaskLaunched("o1") + + // accounts for a killed task + val status = createTaskStatus("0", "s1", TaskState.TASK_KILLED) + backend.statusUpdate(driver, status) + 
verify(driver, times(1)).reviveOffers() + + // Launches a new task on a valid offer from the same slave + offerResources(List(offer2)) + verifyTaskLaunched("o2") + } + + test("mesos supports spark.executor.cores") { + val executorCores = 4 + setBackend(Map("spark.executor.cores" -> executorCores.toString)) + + val executorMemory = backend.executorMemory(sc) + val offers = List((executorMemory * 2, executorCores + 1)) + offerResources(offers) + + val taskInfos = verifyTaskLaunched("o1") + assert(taskInfos.size() == 1) + + val cpus = backend.getResource(taskInfos.iterator().next().getResourcesList, "cpus") + assert(cpus == executorCores) + } + + test("mesos supports unset spark.executor.cores") { + setBackend() + + val executorMemory = backend.executorMemory(sc) + val offerCores = 10 + offerResources(List((executorMemory * 2, offerCores))) + + val taskInfos = verifyTaskLaunched("o1") + assert(taskInfos.size() == 1) + + val cpus = backend.getResource(taskInfos.iterator().next().getResourcesList, "cpus") + assert(cpus == offerCores) + } + + test("mesos does not acquire more than spark.cores.max") { + val maxCores = 10 + setBackend(Map("spark.cores.max" -> maxCores.toString)) + + val executorMemory = backend.executorMemory(sc) + offerResources(List((executorMemory, maxCores + 1))) + + val taskInfos = verifyTaskLaunched("o1") + assert(taskInfos.size() == 1) + + val cpus = backend.getResource(taskInfos.iterator().next().getResourcesList, "cpus") + assert(cpus == maxCores) + } + + test("mesos declines offers that violate attribute constraints") { + setBackend(Map("spark.mesos.constraints" -> "x:true")) + offerResources(List((backend.executorMemory(sc), 4))) + verifyDeclinedOffer(driver, createOfferId("o1"), true) + } + + test("mesos assigns tasks round-robin on offers") { + val executorCores = 4 + val maxCores = executorCores * 2 + setBackend(Map("spark.executor.cores" -> executorCores.toString, + "spark.cores.max" -> maxCores.toString)) + + val executorMemory = backend.executorMemory(sc) + offerResources(List( + (executorMemory * 2, executorCores * 2), + (executorMemory * 2, executorCores * 2))) + + verifyTaskLaunched("o1") + verifyTaskLaunched("o2") + } + + test("mesos creates multiple executors on a single slave") { + val executorCores = 4 + setBackend(Map("spark.executor.cores" -> executorCores.toString)) + + // offer with room for two executors + val executorMemory = backend.executorMemory(sc) + offerResources(List((executorMemory * 2, executorCores * 2))) + + // verify two executors were started on a single offer + val taskInfos = verifyTaskLaunched("o1") + assert(taskInfos.size() == 2) + } + + test("mesos doesn't register twice with the same shuffle service") { + setBackend(Map("spark.shuffle.service.enabled" -> "true")) + val (mem, cpu) = (backend.executorMemory(sc), 4) + + val offer1 = createOffer("o1", "s1", mem, cpu) + backend.resourceOffers(driver, List(offer1).asJava) + verifyTaskLaunched("o1") + + val offer2 = createOffer("o2", "s1", mem, cpu) + backend.resourceOffers(driver, List(offer2).asJava) + verifyTaskLaunched("o2") + + val status1 = createTaskStatus("0", "s1", TaskState.TASK_RUNNING) + backend.statusUpdate(driver, status1) + + val status2 = createTaskStatus("1", "s1", TaskState.TASK_RUNNING) + backend.statusUpdate(driver, status2) + verify(externalShuffleClient, times(1)).registerDriverWithShuffleService(anyString, anyInt) + } + + test("mesos kills an executor when told") { + setBackend() + + val (mem, cpu) = (backend.executorMemory(sc), 4) + + val offer1 = createOffer("o1", 
"s1", mem, cpu) + backend.resourceOffers(driver, List(offer1).asJava) + verifyTaskLaunched("o1") + + backend.doKillExecutors(List("0")) + verify(driver, times(1)).killTask(createTaskId("0")) + } + + private def verifyDeclinedOffer(driver: SchedulerDriver, + offerId: OfferID, + filter: Boolean = false): Unit = { + if (filter) { + verify(driver, times(1)).declineOffer(Matchers.eq(offerId), anyObject[Filters]) + } else { + verify(driver, times(1)).declineOffer(Matchers.eq(offerId)) + } + } + + private def offerResources(offers: List[(Int, Int)], startId: Int = 1): Unit = { + val mesosOffers = offers.zipWithIndex.map {case (offer, i) => + createOffer(s"o${i + startId}", s"s${i + startId}", offer._1, offer._2)} + + backend.resourceOffers(driver, mesosOffers.asJava) + } + + private def verifyTaskLaunched(offerId: String): java.util.Collection[TaskInfo] = { + val captor = ArgumentCaptor.forClass(classOf[java.util.Collection[TaskInfo]]) + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(createOfferId(offerId))), + captor.capture()) + captor.getValue + } + + private def createTaskStatus(taskId: String, slaveId: String, state: TaskState): TaskStatus = { + TaskStatus.newBuilder() + .setTaskId(TaskID.newBuilder().setValue(taskId).build()) + .setSlaveId(SlaveID.newBuilder().setValue(slaveId).build()) + .setState(state) + .build + } + + + private def createOfferId(offerId: String): OfferID = { + OfferID.newBuilder().setValue(offerId).build() + } + + private def createSlaveId(slaveId: String): SlaveID = { + SlaveID.newBuilder().setValue(slaveId).build() + } + + private def createExecutorId(executorId: String): ExecutorID = { + ExecutorID.newBuilder().setValue(executorId).build() + } + + private def createTaskId(taskId: String): TaskID = { + TaskID.newBuilder().setValue(taskId).build() + } + private def createOffer(offerId: String, slaveId: String, mem: Int, cpu: Int): Offer = { val builder = Offer.newBuilder() builder.addResourcesBuilder() @@ -47,8 +268,7 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite .setName("cpus") .setType(Value.Type.SCALAR) .setScalar(Scalar.newBuilder().setValue(cpu)) - builder.setId(OfferID.newBuilder() - .setValue(offerId).build()) + builder.setId(createOfferId(offerId)) .setFrameworkId(FrameworkID.newBuilder() .setValue("f1")) .setSlaveId(SlaveID.newBuilder().setValue(slaveId)) @@ -58,130 +278,55 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite private def createSchedulerBackend( taskScheduler: TaskSchedulerImpl, - driver: SchedulerDriver): CoarseMesosSchedulerBackend = { + driver: SchedulerDriver, + shuffleClient: MesosExternalShuffleClient, + endpoint: RpcEndpointRef): CoarseMesosSchedulerBackend = { val securityManager = mock[SecurityManager] + val backend = new CoarseMesosSchedulerBackend(taskScheduler, sc, "master", securityManager) { override protected def createSchedulerDriver( - masterUrl: String, - scheduler: Scheduler, - sparkUser: String, - appName: String, - conf: SparkConf, - webuiUrl: Option[String] = None, - checkpoint: Option[Boolean] = None, - failoverTimeout: Option[Double] = None, - frameworkId: Option[String] = None): SchedulerDriver = driver + masterUrl: String, + scheduler: Scheduler, + sparkUser: String, + appName: String, + conf: SparkConf, + webuiUrl: Option[String] = None, + checkpoint: Option[Boolean] = None, + failoverTimeout: Option[Double] = None, + frameworkId: Option[String] = None): SchedulerDriver = driver + + override protected def getShuffleClient(): MesosExternalShuffleClient = 
shuffleClient + + override protected def createDriverEndpointRef( + properties: ArrayBuffer[(String, String)]): RpcEndpointRef = endpoint + markRegistered() } backend.start() backend } - var sparkConf: SparkConf = _ - - before { + private def setBackend(sparkConfVars: Map[String, String] = null) { sparkConf = (new SparkConf) .setMaster("local[*]") .setAppName("test-mesos-dynamic-alloc") .setSparkHome("/path") - sc = new SparkContext(sparkConf) - } - - test("mesos supports killing and limiting executors") { - val driver = mock[SchedulerDriver] - when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) - val taskScheduler = mock[TaskSchedulerImpl] - when(taskScheduler.sc).thenReturn(sc) - - sparkConf.set("spark.driver.host", "driverHost") - sparkConf.set("spark.driver.port", "1234") - - val backend = createSchedulerBackend(taskScheduler, driver) - val minMem = backend.calculateTotalMemory(sc) - val minCpu = 4 - - val mesosOffers = new java.util.ArrayList[Offer] - mesosOffers.add(createOffer("o1", "s1", minMem, minCpu)) - - val taskID0 = TaskID.newBuilder().setValue("0").build() - - backend.resourceOffers(driver, mesosOffers) - verify(driver, times(1)).launchTasks( - Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), - any[util.Collection[TaskInfo]], - any[Filters]) - - // simulate the allocation manager down-scaling executors - backend.doRequestTotalExecutors(0) - assert(backend.doKillExecutors(Seq("s1/0"))) - verify(driver, times(1)).killTask(taskID0) - - val mesosOffers2 = new java.util.ArrayList[Offer] - mesosOffers2.add(createOffer("o2", "s2", minMem, minCpu)) - backend.resourceOffers(driver, mesosOffers2) - - verify(driver, times(1)) - .declineOffer(OfferID.newBuilder().setValue("o2").build()) - - // Verify we didn't launch any new executor - assert(backend.slaveIdsWithExecutors.size === 1) - - backend.doRequestTotalExecutors(2) - backend.resourceOffers(driver, mesosOffers2) - verify(driver, times(1)).launchTasks( - Matchers.eq(Collections.singleton(mesosOffers2.get(0).getId)), - any[util.Collection[TaskInfo]], - any[Filters]) + if (sparkConfVars != null) { + for (attr <- sparkConfVars) { + sparkConf.set(attr._1, attr._2) + } + } - assert(backend.slaveIdsWithExecutors.size === 2) - backend.slaveLost(driver, SlaveID.newBuilder().setValue("s1").build()) - assert(backend.slaveIdsWithExecutors.size === 1) - } + sc = new SparkContext(sparkConf) - test("mesos supports killing and relaunching tasks with executors") { - val driver = mock[SchedulerDriver] + driver = mock[SchedulerDriver] when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) - val taskScheduler = mock[TaskSchedulerImpl] + taskScheduler = mock[TaskSchedulerImpl] when(taskScheduler.sc).thenReturn(sc) + externalShuffleClient = mock[MesosExternalShuffleClient] + driverEndpoint = mock[RpcEndpointRef] - val backend = createSchedulerBackend(taskScheduler, driver) - val minMem = backend.calculateTotalMemory(sc) + 1024 - val minCpu = 4 - - val mesosOffers = new java.util.ArrayList[Offer] - val offer1 = createOffer("o1", "s1", minMem, minCpu) - mesosOffers.add(offer1) - - val offer2 = createOffer("o2", "s1", minMem, 1); - - backend.resourceOffers(driver, mesosOffers) - - verify(driver, times(1)).launchTasks( - Matchers.eq(Collections.singleton(offer1.getId)), - anyObject(), - anyObject[Filters]) - - // Simulate task killed, executor no longer running - val status = TaskStatus.newBuilder() - .setTaskId(TaskID.newBuilder().setValue("0").build()) - .setSlaveId(SlaveID.newBuilder().setValue("s1").build()) - 
.setState(TaskState.TASK_KILLED) - .build - - backend.statusUpdate(driver, status) - assert(!backend.slaveIdsWithExecutors.contains("s1")) - - mesosOffers.clear() - mesosOffers.add(offer2) - backend.resourceOffers(driver, mesosOffers) - assert(backend.slaveIdsWithExecutors.contains("s1")) - - verify(driver, times(1)).launchTasks( - Matchers.eq(Collections.singleton(offer2.getId)), - anyObject(), - anyObject[Filters]) - - verify(driver, times(1)).reviveOffers() + backend = createSchedulerBackend(taskScheduler, driver, externalShuffleClient, driverEndpoint) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index e111e2e9f616..3fb3279073f2 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -189,7 +189,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") - val minMem = backend.calculateTotalMemory(sc) + val minMem = backend.executorMemory(sc) val minCpu = 4 val mesosOffers = new java.util.ArrayList[Offer] diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala index 2eb43b731338..85437b2f8081 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala @@ -41,20 +41,20 @@ class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoS test("use at-least minimum overhead") { val f = fixture when(f.sc.executorMemory).thenReturn(512) - utils.calculateTotalMemory(f.sc) shouldBe 896 + utils.executorMemory(f.sc) shouldBe 896 } test("use overhead if it is greater than minimum value") { val f = fixture when(f.sc.executorMemory).thenReturn(4096) - utils.calculateTotalMemory(f.sc) shouldBe 4505 + utils.executorMemory(f.sc) shouldBe 4505 } test("use spark.mesos.executor.memoryOverhead (if set)") { val f = fixture when(f.sc.executorMemory).thenReturn(1024) f.sparkConf.set("spark.mesos.executor.memoryOverhead", "512") - utils.calculateTotalMemory(f.sc) shouldBe 1536 + utils.executorMemory(f.sc) shouldBe 1536 } test("parse a non-empty constraint string correctly") { diff --git a/docs/configuration.md b/docs/configuration.md index 93b399d819cc..517813450585 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -825,13 +825,18 @@ Apart from these, the following properties are also available, and may be useful spark.executor.cores - 1 in YARN mode, all the available cores on the worker in standalone mode. - The number of cores to use on each executor. For YARN and standalone mode only. + 1 in YARN mode, all the available cores on the worker in + standalone and Mesos coarse-grained modes. + + + The number of cores to use on each executor. - In standalone mode, setting this parameter allows an application to run multiple executors on - the same worker, provided that there are enough cores on that worker. Otherwise, only one - executor per application will run on each worker. 
+ In standalone and Mesos coarse-grained modes, setting this + parameter allows an application to run multiple executors on the + same worker, provided that there are enough cores on that + worker. Otherwise, only one executor per application will run on + each worker. diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 0ef1ccb36e11..0eff06f6a46d 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -280,9 +280,11 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.extra.cores 0 - Set the extra amount of cpus to request per task. This setting is only used for Mesos coarse grain mode. - The total amount of cores requested per task is the number of cores in the offer plus the extra cores configured. - Note that total amount of cores the executor will request in total will not exceed the spark.cores.max setting. + Set the extra number of cores for an executor to advertise. This + does not result in more cores allocated. It instead means that an + executor will "pretend" it has more cores, so that the driver will + send it more tasks. Use this to increase parallelism. This + setting is only used for Mesos coarse-grained mode.
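---

A few hedged sketches follow, illustrating the behaviour that the patch above implements and that its updated test suites pin down. First, the renamed `MesosSchedulerUtils.executorMemory` helper sizes the Mesos resource request as `spark.executor.memory` plus an overhead, and the new `MesosSchedulerUtilsSuite` expectations (512 -> 896, 4096 -> 4505, 1024 with an explicit 512 MiB overhead -> 1536) fix that arithmetic. The sketch below reproduces it under the assumption that the overhead fraction is 0.10 and the minimum overhead is 384 MiB (values implied by those tests, not quoted from the patch):

    // Standalone sketch of the Mesos executor memory sizing rule exercised by
    // MesosSchedulerUtilsSuite. The constants mirror the values implied by the
    // tests; treat them as assumptions rather than the canonical implementation.
    object ExecutorMemorySketch {
      val MemoryOverheadFraction = 0.10
      val MemoryOverheadMinimum = 384 // MiB

      /** Total MiB to request from Mesos for one executor. */
      def executorMemory(executorMemoryMiB: Int, overheadOverride: Option[Int] = None): Int = {
        val overhead = overheadOverride.getOrElse(
          math.max(MemoryOverheadFraction * executorMemoryMiB, MemoryOverheadMinimum).toInt)
        overhead + executorMemoryMiB
      }

      def main(args: Array[String]): Unit = {
        println(executorMemory(512))             // 896  -- the minimum overhead wins
        println(executorMemory(4096))            // 4505 -- the 10% overhead wins
        println(executorMemory(1024, Some(512))) // 1536 -- explicit spark.mesos.executor.memoryOverhead
      }
    }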
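Second, the new `executorCores`/`canLaunchTask` pair in `CoarseMesosSchedulerBackend` decides how many CPUs one task takes from an offer: `spark.executor.cores` if set, otherwise as many of the offer's CPUs as `spark.cores.max` still allows. The following is a simplified, self-contained model of that decision (the class and field names are illustrative, not the backend's actual state):

    // Simplified model of the offer-evaluation logic shown in the patch above.
    // Only the CPU and memory checks are modelled; the executor-limit and
    // slave-failure checks from canLaunchTask are omitted for brevity.
    case class OfferDecision(launch: Boolean, taskCpus: Int)

    class CoarseGrainedSizingSketch(
        confExecutorCores: Option[Int], // spark.executor.cores, if set
        maxCores: Int) {                // spark.cores.max

      private var totalCoresAcquired = 0

      private def executorCores(offerCpus: Int): Int =
        confExecutorCores.getOrElse(math.min(offerCpus, maxCores - totalCoresAcquired))

      def evaluate(offerCpus: Int, offerMem: Int, neededMem: Int): OfferDecision = {
        val cpus = executorCores(offerCpus)
        val ok = cpus > 0 &&
          cpus <= offerCpus &&
          cpus + totalCoresAcquired <= maxCores &&
          neededMem <= offerMem
        if (ok) totalCoresAcquired += cpus
        OfferDecision(ok, if (ok) cpus else 0)
      }
    }

    // With spark.executor.cores unset and spark.cores.max = 10, a 10-CPU offer is
    // fully consumed by one executor, matching the "supports unset
    // spark.executor.cores" and "does not acquire more than spark.cores.max" tests.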
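Third, the `statusUpdate` changes track per-slave state in the new `Slave` class: terminal tasks release their cores, failed tasks increment the slave's failure count, and a slave with too many failures is effectively blacklisted by `canLaunchTask`. A toy model of that bookkeeping, assuming a failure threshold of 2 (the patch itself does not quote the value of `MAX_SLAVE_FAILURES`):

    import scala.collection.mutable

    // Minimal model of the per-slave state and the terminal-status branch of
    // statusUpdate. The real backend also frees the task's cores and notifies
    // the driver; only the blacklisting decision is modelled here.
    class SlaveState(val hostname: String) {
      val taskIds = mutable.HashSet.empty[String]
      var taskFailures = 0
      var shuffleRegistered = false
    }

    object BlacklistSketch {
      val MaxSlaveFailures = 2 // assumed threshold for illustration

      /** Returns true if the slave should no longer receive new executors. */
      def onTerminalUpdate(slave: SlaveState, taskId: String, failed: Boolean): Boolean = {
        slave.taskIds.remove(taskId)
        if (failed) slave.taskFailures += 1
        slave.taskFailures >= MaxSlaveFailures
      }
    }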
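Finally, the documentation changes above describe how `spark.executor.cores`, `spark.cores.max`, and `spark.mesos.extra.cores` interact in Mesos coarse-grained mode. A short configuration sketch, with a placeholder master URL and illustrative sizes rather than recommended values:

    import org.apache.spark.{SparkConf, SparkContext}

    // Illustrative settings only; the master URL and the numbers are placeholders.
    val conf = new SparkConf()
      .setMaster("mesos://zk://zk1:2181/mesos")  // hypothetical Mesos master
      .setAppName("coarse-grained-sizing-example")
      .set("spark.mesos.coarse", "true")         // coarse-grained mode
      .set("spark.cores.max", "40")              // total cores across all executors
      .set("spark.executor.cores", "4")          // per-executor size; allows several
                                                 // executors on one Mesos slave
      .set("spark.mesos.extra.cores", "1")       // advertise one extra core per executor
                                                 // (more tasks scheduled, not more CPUs)

    val sc = new SparkContext(conf)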