From 85a845f7549fb40f85587036ca49cafc9064b1d5 Mon Sep 17 00:00:00 2001
From: Takeshi YAMAMURO
Date: Mon, 25 Jan 2016 14:58:45 +0900
Subject: [PATCH 01/21] Skip unnecessary final group-by when input data is
 already clustered

---
 .../spark/sql/execution/SparkStrategies.scala |  9 +++
 .../sql/execution/aggregate/AggUtils.scala    | 55 +++++++++++++------
 .../org/apache/spark/sql/DataFrameSuite.scala | 18 +++---
 3 files changed, 56 insertions(+), 26 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 4aaf454285f4f..8647fbbf2910c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -271,10 +271,19 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
             planLater(child))
         }
       } else if (functionsWithDistinct.isEmpty) {
+        // Check if the child operator satisfies the group-by distribution requirements
+        val childPlan = planLater(child)
+        val skipUnnecessaryAggregate = if (groupingExpressions != Nil) {
+          childPlan.outputPartitioning.satisfies(ClusteredDistribution(groupingExpressions))
+        } else {
+          false
+        }
+
         aggregate.AggUtils.planAggregateWithoutDistinct(
           groupingExpressions,
           aggregateExpressions,
           resultExpressions,
+          skipUnnecessaryAggregate,
           planLater(child))
       } else {
         aggregate.AggUtils.planAggregateWithOneDistinct(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index 4fbb9d554c9bf..f4deb375a30a3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -81,20 +81,38 @@ object AggUtils {
       groupingExpressions: Seq[NamedExpression],
       aggregateExpressions: Seq[AggregateExpression],
       resultExpressions: Seq[NamedExpression],
+      skipUnnecessaryAggregate: Boolean,
       child: SparkPlan): Seq[SparkPlan] = {
     // Check if we can use HashAggregate.
-    // 1. Create an Aggregate Operator for partial aggregations.
     val groupingAttributes = groupingExpressions.map(_.toAttribute)
-    val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial))
-    val partialAggregateAttributes =
-      partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
-    val partialResultExpressions =
-      groupingAttributes ++
-        partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
-
-    val partialAggregate = createAggregate(
+
+    if (skipUnnecessaryAggregate) {
+      // A single-stage aggregation is enough to get the final result because input data are
+      // already clustered.
+      val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
+      val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute)
+
+      createAggregate(
+        requiredChildDistributionExpressions = Some(groupingAttributes),
+        groupingExpressions = groupingAttributes,
+        aggregateExpressions = completeAggregateExpressions,
+        aggregateAttributes = completeAggregateAttributes,
+        initialInputBufferOffset = 0,
+        resultExpressions = resultExpressions,
+        child = child
+      ) :: Nil
+    } else {
+      // 1. Create an Aggregate Operator for partial aggregations.
+ + val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial)) + val partialAggregateAttributes = + partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + val partialResultExpressions = + groupingAttributes ++ + partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) + + val partialAggregate = createAggregate( requiredChildDistributionExpressions = None, groupingExpressions = groupingExpressions, aggregateExpressions = partialAggregateExpressions, @@ -103,13 +121,13 @@ object AggUtils { resultExpressions = partialResultExpressions, child = child) - // 2. Create an Aggregate Operator for final aggregations. - val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) - // The attributes of the final aggregation buffer, which is presented as input to the result - // projection: - val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) + // 2. Create an Aggregate Operator for final aggregations. + val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) - val finalAggregate = createAggregate( + val finalAggregate = createAggregate( requiredChildDistributionExpressions = Some(groupingAttributes), groupingExpressions = groupingAttributes, aggregateExpressions = finalAggregateExpressions, @@ -118,7 +136,8 @@ object AggUtils { resultExpressions = resultExpressions, child = partialAggregate) - finalAggregate :: Nil + finalAggregate :: Nil + } } def planAggregateWithOneDistinct( @@ -232,7 +251,7 @@ object AggUtils { // aggregateFunctionToAttribute val attr = functionsWithDistinct(i).resultAttribute (expr, attr) - }.unzip + }.unzip createAggregate( requiredChildDistributionExpressions = Some(groupingAttributes), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 499f3180379c2..e1419cb129aa0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1248,17 +1248,18 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } /** - * Verifies that there is no Exchange between the Aggregations for `df` + * Verifies that there is a single Aggregation for `df` */ - private def verifyNonExchangingAgg(df: DataFrame) = { + private def verifyNonExchangingSingleAgg(df: DataFrame) = { var atFirstAgg: Boolean = false df.queryExecution.executedPlan.foreach { - case agg: HashAggregateExec => - atFirstAgg = !atFirstAgg - case _ => + case agg: HashAggregateExec => { if (atFirstAgg) { - fail("Should not have operators between the two aggregations") + fail("Should not have back to back Aggregates") } + atFirstAgg = true + } + case _ => } } @@ -1292,9 +1293,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { // Group by the column we are distributed by. 
This should generate a plan with no exchange // between the aggregates val df3 = testData.repartition($"key").groupBy("key").count() - verifyNonExchangingAgg(df3) - verifyNonExchangingAgg(testData.repartition($"key", $"value") + verifyNonExchangingSingleAgg(df3) + verifyNonExchangingSingleAgg(testData.repartition($"key", $"value") .groupBy("key", "value").count()) + verifyNonExchangingSingleAgg(testData.repartition($"key").groupBy("key", "value").count()) // Grouping by just the first distributeBy expr, need to exchange. verifyExchangingAgg(testData.repartition($"key", $"value") From f3659b7d60b012ae432b87608325b01fe136912f Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Fri, 5 Feb 2016 16:37:34 +0900 Subject: [PATCH 02/21] Rename an argument in planAggregateWithoutDistinct --- .../org/apache/spark/sql/execution/SparkStrategies.scala | 4 ++-- .../org/apache/spark/sql/execution/aggregate/AggUtils.scala | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 8647fbbf2910c..55f8fdd7b4914 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -273,7 +273,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } else if (functionsWithDistinct.isEmpty) { // Check if the child operator satisfies the group-by distribution requirements val childPlan = planLater(child) - val skipUnnecessaryAggregate = if (groupingExpressions != Nil) { + val canPatialAggregate = if (groupingExpressions != Nil) { childPlan.outputPartitioning.satisfies(ClusteredDistribution(groupingExpressions)) } else { false @@ -283,7 +283,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { groupingExpressions, aggregateExpressions, resultExpressions, - skipUnnecessaryAggregate, + canPatialAggregate, planLater(child)) } else { aggregate.AggUtils.planAggregateWithOneDistinct( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index f4deb375a30a3..9daef122d4a54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -81,13 +81,13 @@ object AggUtils { groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], resultExpressions: Seq[NamedExpression], - skipUnnecessaryAggregate: Boolean, + partialAggregation: Boolean, child: SparkPlan): Seq[SparkPlan] = { // Check if we can use HashAggregate. val groupingAttributes = groupingExpressions.map(_.toAttribute) - if (skipUnnecessaryAggregate) { + if (partialAggregation) { // A single-stage aggregation is enough to get the final result because input data are // already clustered. 
val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete)) From c4242ea6051eea69dc4de74063525d299a9c18b3 Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Mon, 25 Apr 2016 15:12:23 +0900 Subject: [PATCH 03/21] Fix style errors --- .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index e1419cb129aa0..cd485770d269c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1253,12 +1253,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { private def verifyNonExchangingSingleAgg(df: DataFrame) = { var atFirstAgg: Boolean = false df.queryExecution.executedPlan.foreach { - case agg: HashAggregateExec => { + case agg: HashAggregateExec => if (atFirstAgg) { fail("Should not have back to back Aggregates") } atFirstAgg = true - } case _ => } } From a5addebb5d01c46e7bb614e04202b9ffd2a7803d Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Tue, 7 Jun 2016 01:46:53 -0700 Subject: [PATCH 04/21] Apply comments --- .../apache/spark/sql/execution/SparkStrategies.scala | 9 --------- .../spark/sql/execution/aggregate/AggUtils.scala | 12 +++++++++--- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 55f8fdd7b4914..4aaf454285f4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -271,19 +271,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { planLater(child)) } } else if (functionsWithDistinct.isEmpty) { - // Check if the child operator satisfies the group-by distribution requirements - val childPlan = planLater(child) - val canPatialAggregate = if (groupingExpressions != Nil) { - childPlan.outputPartitioning.satisfies(ClusteredDistribution(groupingExpressions)) - } else { - false - } - aggregate.AggUtils.planAggregateWithoutDistinct( groupingExpressions, aggregateExpressions, resultExpressions, - canPatialAggregate, planLater(child)) } else { aggregate.AggUtils.planAggregateWithOneDistinct( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 9daef122d4a54..b34bba525eac2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec} @@ -81,13 +82,18 @@ object AggUtils { groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], resultExpressions: Seq[NamedExpression], - partialAggregation: Boolean, child: SparkPlan): Seq[SparkPlan] = { - // Check if we can use HashAggregate. 
+ // Check if the child operator satisfies the group-by distribution requirements + val skipPartialAggregation = if (groupingExpressions != Nil) { + child.outputPartitioning.satisfies(ClusteredDistribution(groupingExpressions)) + } else { + false + } + // Check if we can use HashAggregate. val groupingAttributes = groupingExpressions.map(_.toAttribute) - if (partialAggregation) { + if (skipPartialAggregation) { // A single-stage aggregation is enough to get the final result because input data are // already clustered. val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete)) From 74fa277b660cedd534440df9cd592f8fe899eab5 Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Tue, 14 Jun 2016 22:34:05 +0900 Subject: [PATCH 05/21] Re-Implement in EnsureRequirements --- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../sql/execution/aggregate/AggUtils.scala | 243 +++++++++--------- .../aggregate/SortAggregateExec.scala | 3 + .../exchange/EnsureRequirements.scala | 23 +- 4 files changed, 145 insertions(+), 126 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 4aaf454285f4f..7f0059f3db28a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -264,7 +264,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { sys.error("Distinct columns cannot exist in Aggregate operator containing " + "aggregate functions which don't support partial aggregation.") } else { - aggregate.AggUtils.planAggregateWithoutPartial( + aggregate.AggUtils.planAggregateWithoutDistinct( groupingExpressions, aggregateExpressions, resultExpressions, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index b34bba525eac2..86d3cd96a42db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec} @@ -28,23 +27,73 @@ import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateSto */ object AggUtils { - def planAggregateWithoutPartial( + private[execution] def isAggregateExec(operator: SparkPlan): Boolean = { + operator.isInstanceOf[HashAggregateExec] || operator.isInstanceOf[SortAggregateExec] + } + + private def createPartialAggregate( groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], - resultExpressions: Seq[NamedExpression], - child: SparkPlan): Seq[SparkPlan] = { - - val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete)) - val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute) - SortAggregateExec( - requiredChildDistributionExpressions = Some(groupingExpressions), + child: SparkPlan): SparkPlan = { + val groupingAttributes = groupingExpressions.map(_.toAttribute) + val 
partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial)) + val partialAggregateAttributes = + partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + val partialResultExpressions = + groupingAttributes ++ + partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) + + createAggregate( + requiredChildDistributionExpressions = None, groupingExpressions = groupingExpressions, - aggregateExpressions = completeAggregateExpressions, - aggregateAttributes = completeAggregateAttributes, + aggregateExpressions = partialAggregateExpressions, + aggregateAttributes = partialAggregateAttributes, initialInputBufferOffset = 0, - resultExpressions = resultExpressions, - child = child - ) :: Nil + resultExpressions = partialResultExpressions, + child = child) + } + + private def updateAggregateMode(mode: AggregateMode) = mode match { + case Partial => PartialMerge + case Complete => Final + case mode => mode + } + + private[execution] def addMapSideAggregate(operator: SparkPlan) + : (SparkPlan, Seq[SparkPlan]) = operator match { + case agg @ HashAggregateExec( + requiredChildDistributionExpressions, + groupingExpressions, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + resultExpressions, + child) => + val newChild = createPartialAggregate(groupingExpressions, aggregateExpressions, child) + val parent = agg.copy( + groupingExpressions = groupingExpressions.map(_.toAttribute), + aggregateExpressions = + aggregateExpressions.map(e => e.copy(mode = updateAggregateMode(e.mode))), + initialInputBufferOffset = groupingExpressions.length + ) + (parent, newChild :: Nil) + + case agg @ SortAggregateExec( + requiredChildDistributionExpressions, + groupingExpressions, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + resultExpressions, + child) => + val newChild = createPartialAggregate(groupingExpressions, aggregateExpressions, child) + val parent = agg.copy( + groupingExpressions = groupingExpressions.map(_.toAttribute), + aggregateExpressions = + aggregateExpressions.map(e => e.copy(mode = updateAggregateMode(e.mode))), + initialInputBufferOffset = groupingExpressions.length + ) + (parent, newChild :: Nil) } private def createAggregate( @@ -83,67 +132,19 @@ object AggUtils { aggregateExpressions: Seq[AggregateExpression], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { - // Check if the child operator satisfies the group-by distribution requirements - val skipPartialAggregation = if (groupingExpressions != Nil) { - child.outputPartitioning.satisfies(ClusteredDistribution(groupingExpressions)) - } else { - false - } - - // Check if we can use HashAggregate. val groupingAttributes = groupingExpressions.map(_.toAttribute) + val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete)) + val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute) - if (skipPartialAggregation) { - // A single-stage aggregation is enough to get the final result because input data are - // already clustered. 
- val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete)) - val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute) - - createAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes, - aggregateExpressions = completeAggregateExpressions, - aggregateAttributes = completeAggregateAttributes, - initialInputBufferOffset = 0, - resultExpressions = resultExpressions, - child = child - ) :: Nil - } else { - // 1. Create an Aggregate Operator for partial aggregations. - - val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial)) - val partialAggregateAttributes = - partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) - val partialResultExpressions = - groupingAttributes ++ - partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) - - val partialAggregate = createAggregate( - requiredChildDistributionExpressions = None, - groupingExpressions = groupingExpressions, - aggregateExpressions = partialAggregateExpressions, - aggregateAttributes = partialAggregateAttributes, - initialInputBufferOffset = 0, - resultExpressions = partialResultExpressions, - child = child) - - // 2. Create an Aggregate Operator for final aggregations. - val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) - // The attributes of the final aggregation buffer, which is presented as input to the result - // projection: - val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) - - val finalAggregate = createAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes, - aggregateExpressions = finalAggregateExpressions, - aggregateAttributes = finalAggregateAttributes, - initialInputBufferOffset = groupingExpressions.length, - resultExpressions = resultExpressions, - child = partialAggregate) - - finalAggregate :: Nil - } + createAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingExpressions, + aggregateExpressions = completeAggregateExpressions, + aggregateAttributes = completeAggregateAttributes, + initialInputBufferOffset = 0, + resultExpressions = resultExpressions, + child = child + ) :: Nil } def planAggregateWithOneDistinct( @@ -167,35 +168,35 @@ object AggUtils { val groupingAttributes = groupingExpressions.map(_.toAttribute) // 1. Create an Aggregate Operator for partial aggregations. - val partialAggregate: SparkPlan = { - val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) - val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) - // We will group by the original grouping expression, plus an additional expression for the - // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping - // expressions will be [key, value]. 
- createAggregate( - groupingExpressions = groupingExpressions ++ namedDistinctExpressions, - aggregateExpressions = aggregateExpressions, - aggregateAttributes = aggregateAttributes, - resultExpressions = groupingAttributes ++ distinctAttributes ++ - aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), - child = child) - } + // val partialAggregate: SparkPlan = { + // val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) + // val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + // // We will group by the original grouping expression, plus an additional expression for the + // // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping + // // expressions will be [key, value]. + // createAggregate( + // groupingExpressions = groupingExpressions ++ namedDistinctExpressions, + // aggregateExpressions = aggregateExpressions, + // aggregateAttributes = aggregateAttributes, + // resultExpressions = groupingAttributes ++ distinctAttributes ++ + // aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + // child = child) + // } // 2. Create an Aggregate Operator for partial merge aggregations. val partialMergeAggregate: SparkPlan = { - val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) createAggregate( requiredChildDistributionExpressions = Some(groupingAttributes ++ distinctAttributes), - groupingExpressions = groupingAttributes ++ distinctAttributes, + groupingExpressions = groupingExpressions ++ namedDistinctExpressions, aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length, resultExpressions = groupingAttributes ++ distinctAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), - child = partialAggregate) + child = child) } // 3. Create an Aggregate operator for partial aggregation (for distinct) @@ -209,35 +210,35 @@ object AggUtils { .asInstanceOf[AggregateFunction] } - val partialDistinctAggregate: SparkPlan = { - val mergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) - // The attributes of the final aggregation buffer, which is presented as input to the result - // projection: - val mergeAggregateAttributes = mergeAggregateExpressions.map(_.resultAttribute) - val (distinctAggregateExpressions, distinctAggregateAttributes) = - rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) => - // We rewrite the aggregate function to a non-distinct aggregation because - // its input will have distinct arguments. - // We just keep the isDistinct setting to true, so when users look at the query plan, - // they still can see distinct aggregations. 
- val expr = AggregateExpression(func, Partial, isDistinct = true) - // Use original AggregationFunction to lookup attributes, which is used to build - // aggregateFunctionToAttribute - val attr = functionsWithDistinct(i).resultAttribute - (expr, attr) - }.unzip - - val partialAggregateResult = groupingAttributes ++ - mergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) ++ - distinctAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) - createAggregate( - groupingExpressions = groupingAttributes, - aggregateExpressions = mergeAggregateExpressions ++ distinctAggregateExpressions, - aggregateAttributes = mergeAggregateAttributes ++ distinctAggregateAttributes, - initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length, - resultExpressions = partialAggregateResult, - child = partialMergeAggregate) - } + // val partialDistinctAggregate: SparkPlan = { + // val mergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + // // The attributes of the final aggregation buffer, which is presented as input to the + // // result projection: + // val mergeAggregateAttributes = mergeAggregateExpressions.map(_.resultAttribute) + // val (distinctAggregateExpressions, distinctAggregateAttributes) = + // rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) => + // // We rewrite the aggregate function to a non-distinct aggregation because + // // its input will have distinct arguments. + // // We just keep the isDistinct setting to true, so when users look at the query plan, + // // they still can see distinct aggregations. + // val expr = AggregateExpression(func, Partial, isDistinct = true) + // // Use original AggregationFunction to lookup attributes, which is used to build + // // aggregateFunctionToAttribute + // val attr = functionsWithDistinct(i).resultAttribute + // (expr, attr) + // }.unzip + + // val partialAggregateResult = groupingAttributes ++ + // mergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) ++ + // distinctAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) + // createAggregate( + // groupingExpressions = groupingAttributes, + // aggregateExpressions = mergeAggregateExpressions ++ distinctAggregateExpressions, + // aggregateAttributes = mergeAggregateAttributes ++ distinctAggregateAttributes, + // initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length, + // resultExpressions = partialAggregateResult, + // child = partialMergeAggregate) + // } // 4. Create an Aggregate Operator for the final aggregation. 
val finalAndCompleteAggregate: SparkPlan = { @@ -261,12 +262,12 @@ object AggUtils { createAggregate( requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes, + groupingExpressions = groupingExpressions, aggregateExpressions = finalAggregateExpressions ++ distinctAggregateExpressions, aggregateAttributes = finalAggregateAttributes ++ distinctAggregateAttributes, initialInputBufferOffset = groupingAttributes.length, resultExpressions = resultExpressions, - child = partialDistinctAggregate) + child = partialMergeAggregate) } finalAndCompleteAggregate :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index 2a81a823c44b3..08988508733d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -121,3 +121,6 @@ case class SortAggregateExec( } } } + +object SortAggregateExec + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 446571aa8409f..bb0236a7b3726 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.aggregate.AggUtils import org.apache.spark.sql.internal.SQLConf /** @@ -151,11 +152,25 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = { val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering - var children: Seq[SparkPlan] = operator.children - assert(requiredChildDistributions.length == children.length) - assert(requiredChildOrderings.length == children.length) + // var children: Seq[SparkPlan] = operator.children + // assert(requiredChildDistributions.length == children.length) + // assert(requiredChildOrderings.length == children.length) // Ensure that the operator's children satisfy their output distribution requirements: + val childrenWithDist = operator.children.zip(requiredChildDistributions) + + // If necessary, add map-side aggregates + var (parent, children) = if (AggUtils.isAggregateExec(operator)) { + val (child, distribution) = childrenWithDist.head + if (!child.outputPartitioning.satisfies(distribution)) { + AggUtils.addMapSideAggregate(operator) + } else { + (operator, child :: Nil) + } + } else { + (operator, operator.children) + } + children = children.zip(requiredChildDistributions).map { case (child, distribution) if child.outputPartitioning.satisfies(distribution) => child @@ -246,7 +261,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { } } - operator.withNewChildren(children) + parent.withNewChildren(children) } def apply(plan: SparkPlan): SparkPlan = plan.transformUp { From b7470852db35ffda701d6149e30f2d8d77d3be6f Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: 
Thu, 16 Jun 2016 11:14:02 +0900 Subject: [PATCH 06/21] Fix bugs in planStreamingAggregation --- .../sql/execution/aggregate/AggUtils.scala | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 86d3cd96a42db..500910ae83e04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -291,20 +291,20 @@ object AggUtils { val groupingAttributes = groupingExpressions.map(_.toAttribute) - val partialAggregate: SparkPlan = { - val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) - val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) - // We will group by the original grouping expression, plus an additional expression for the - // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping - // expressions will be [key, value]. - createAggregate( - groupingExpressions = groupingExpressions, - aggregateExpressions = aggregateExpressions, - aggregateAttributes = aggregateAttributes, - resultExpressions = groupingAttributes ++ - aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), - child = child) - } + // val partialAggregate: SparkPlan = { + // val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) + // val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + // // We will group by the original grouping expression, plus an additional expression for the + // // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping + // // expressions will be [key, value]. 
+ // createAggregate( + // groupingExpressions = groupingExpressions, + // aggregateExpressions = aggregateExpressions, + // aggregateAttributes = aggregateAttributes, + // resultExpressions = groupingAttributes ++ + // aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + // child = child) + // } val partialMerged1: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) @@ -312,13 +312,13 @@ object AggUtils { createAggregate( requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes, + groupingExpressions = groupingExpressions, aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, initialInputBufferOffset = groupingAttributes.length, resultExpressions = groupingAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), - child = partialAggregate) + child = child) } val restored = StateStoreRestoreExec(groupingAttributes, None, partialMerged1) From cb5172742083a8293ed303925a0d1f85340f5bb6 Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Thu, 16 Jun 2016 11:24:46 +0900 Subject: [PATCH 07/21] Fix comments in AggUtils --- .../sql/execution/aggregate/AggUtils.scala | 74 ++----------------- 1 file changed, 7 insertions(+), 67 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 500910ae83e04..7e9e7787724ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -168,22 +168,6 @@ object AggUtils { val groupingAttributes = groupingExpressions.map(_.toAttribute) // 1. Create an Aggregate Operator for partial aggregations. - // val partialAggregate: SparkPlan = { - // val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) - // val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) - // // We will group by the original grouping expression, plus an additional expression for the - // // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping - // // expressions will be [key, value]. - // createAggregate( - // groupingExpressions = groupingExpressions ++ namedDistinctExpressions, - // aggregateExpressions = aggregateExpressions, - // aggregateAttributes = aggregateAttributes, - // resultExpressions = groupingAttributes ++ distinctAttributes ++ - // aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), - // child = child) - // } - - // 2. Create an Aggregate Operator for partial merge aggregations. val partialMergeAggregate: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) @@ -199,7 +183,7 @@ object AggUtils { child = child) } - // 3. Create an Aggregate operator for partial aggregation (for distinct) + // 2. 
Create an Aggregate operator for partial aggregation (for distinct)
     val distinctColumnAttributeLookup = distinctExpressions.zip(distinctAttributes).toMap
     val rewrittenDistinctFunctions = functionsWithDistinct.map {
       // Children of an AggregateFunction with DISTINCT keyword has already
@@ -210,37 +194,7 @@ object AggUtils {
         .asInstanceOf[AggregateFunction]
     }
 
-    // val partialDistinctAggregate: SparkPlan = {
-    //   val mergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
-    //   // The attributes of the final aggregation buffer, which is presented as input to the
-    //   // result projection:
-    //   val mergeAggregateAttributes = mergeAggregateExpressions.map(_.resultAttribute)
-    //   val (distinctAggregateExpressions, distinctAggregateAttributes) =
-    //     rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) =>
-    //       // We rewrite the aggregate function to a non-distinct aggregation because
-    //       // its input will have distinct arguments.
-    //       // We just keep the isDistinct setting to true, so when users look at the query plan,
-    //       // they still can see distinct aggregations.
-    //       val expr = AggregateExpression(func, Partial, isDistinct = true)
-    //       // Use original AggregationFunction to lookup attributes, which is used to build
-    //       // aggregateFunctionToAttribute
-    //       val attr = functionsWithDistinct(i).resultAttribute
-    //       (expr, attr)
-    //     }.unzip
-
-    //   val partialAggregateResult = groupingAttributes ++
-    //     mergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) ++
-    //     distinctAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
-    //   createAggregate(
-    //     groupingExpressions = groupingAttributes,
-    //     aggregateExpressions = mergeAggregateExpressions ++ distinctAggregateExpressions,
-    //     aggregateAttributes = mergeAggregateAttributes ++ distinctAggregateAttributes,
-    //     initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length,
-    //     resultExpressions = partialAggregateResult,
-    //     child = partialMergeAggregate)
-    // }
-
-    // 4. Create an Aggregate Operator for the final aggregation.
+    // 3. Create an Aggregate Operator for the final aggregation.
     val finalAndCompleteAggregate: SparkPlan = {
       val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
       // The attributes of the final aggregation buffer, which is presented as input to the result
@@ -229,13 +229,14 @@ object AggUtils {
 
   /**
    * Plans a streaming aggregation using the following progression:
-   * - Partial Aggregation
-   * - Shuffle
-   * - Partial Merge (now there is at most 1 tuple per group)
+   * - Partial Aggregation (now there is at most 1 tuple per group)
    * - StateStoreRestore (now there is 1 tuple from this batch + optionally one from the previous)
    * - PartialMerge (now there is at most 1 tuple per group)
    * - StateStoreSave (saves the tuple for the next batch)
   * - Complete (output the current result of the aggregation)
+   *
+   * If the first aggregation needs a shuffle to satisfy its distribution, a map-side partial
+   * aggregation and a shuffle are added in `EnsureRequirements`.
*/ def planStreamingAggregation( groupingExpressions: Seq[NamedExpression], @@ -291,23 +246,8 @@ object AggUtils { val groupingAttributes = groupingExpressions.map(_.toAttribute) - // val partialAggregate: SparkPlan = { - // val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) - // val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) - // // We will group by the original grouping expression, plus an additional expression for the - // // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping - // // expressions will be [key, value]. - // createAggregate( - // groupingExpressions = groupingExpressions, - // aggregateExpressions = aggregateExpressions, - // aggregateAttributes = aggregateAttributes, - // resultExpressions = groupingAttributes ++ - // aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), - // child = child) - // } - val partialMerged1: SparkPlan = { - val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) createAggregate( requiredChildDistributionExpressions = From 5de4871d5670b156c1cfab54779257624eb4243c Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Thu, 16 Jun 2016 12:15:59 +0900 Subject: [PATCH 08/21] Refactor logics in EnsureRequirements --- .../sql/execution/aggregate/AggUtils.scala | 22 +++++++-------- .../exchange/EnsureRequirements.scala | 27 ++++++++++--------- 2 files changed, 26 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 7e9e7787724ca..fb808bbbd43f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -60,7 +60,7 @@ object AggUtils { } private[execution] def addMapSideAggregate(operator: SparkPlan) - : (SparkPlan, Seq[SparkPlan]) = operator match { + : (SparkPlan, SparkPlan) = operator match { case agg @ HashAggregateExec( requiredChildDistributionExpressions, groupingExpressions, @@ -69,14 +69,14 @@ object AggUtils { initialInputBufferOffset, resultExpressions, child) => - val newChild = createPartialAggregate(groupingExpressions, aggregateExpressions, child) - val parent = agg.copy( + val mapSideAgg = createPartialAggregate(groupingExpressions, aggregateExpressions, child) + val mergeAgg = agg.copy( groupingExpressions = groupingExpressions.map(_.toAttribute), aggregateExpressions = aggregateExpressions.map(e => e.copy(mode = updateAggregateMode(e.mode))), - initialInputBufferOffset = groupingExpressions.length - ) - (parent, newChild :: Nil) + initialInputBufferOffset = groupingExpressions.length) + + (mergeAgg, mapSideAgg) case agg @ SortAggregateExec( requiredChildDistributionExpressions, @@ -86,14 +86,14 @@ object AggUtils { initialInputBufferOffset, resultExpressions, child) => - val newChild = createPartialAggregate(groupingExpressions, aggregateExpressions, child) - val parent = agg.copy( + val mapSideAgg = createPartialAggregate(groupingExpressions, aggregateExpressions, child) + val mergeAgg = agg.copy( groupingExpressions = groupingExpressions.map(_.toAttribute), aggregateExpressions = aggregateExpressions.map(e => e.copy(mode = updateAggregateMode(e.mode))), - initialInputBufferOffset 
= groupingExpressions.length - ) - (parent, newChild :: Nil) + initialInputBufferOffset = groupingExpressions.length) + + (mergeAgg, mapSideAgg) } private def createAggregate( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index bb0236a7b3726..e76d01c74c69b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -159,25 +159,28 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { // Ensure that the operator's children satisfy their output distribution requirements: val childrenWithDist = operator.children.zip(requiredChildDistributions) - // If necessary, add map-side aggregates var (parent, children) = if (AggUtils.isAggregateExec(operator)) { + // If an aggregation need a shuffle to satisfy its distribution, a map-side partial an + // aggregation and a shuffle are added as children. val (child, distribution) = childrenWithDist.head if (!child.outputPartitioning.satisfies(distribution)) { - AggUtils.addMapSideAggregate(operator) + val (mergeAgg, mapSideAgg) = AggUtils.addMapSideAggregate(operator) + val newChild = ShuffleExchange( + createPartitioning(distribution, defaultNumPreShufflePartitions), mapSideAgg) + (mergeAgg, newChild :: Nil) } else { (operator, child :: Nil) } } else { - (operator, operator.children) - } - - children = children.zip(requiredChildDistributions).map { - case (child, distribution) if child.outputPartitioning.satisfies(distribution) => - child - case (child, BroadcastDistribution(mode)) => - BroadcastExchangeExec(mode, child) - case (child, distribution) => - ShuffleExchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child) + val newChildren = childrenWithDist.map { + case (child, distribution) if child.outputPartitioning.satisfies(distribution) => + child + case (child, BroadcastDistribution(mode)) => + BroadcastExchangeExec(mode, child) + case (child, distribution) => + ShuffleExchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child) + } + (operator, newChildren) } // If the operator has multiple children and specifies child output distributions (e.g. 
join), From dc6a7f237f70c137cb4e27f9e02f1706dc694d8e Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Fri, 17 Jun 2016 15:24:23 +0900 Subject: [PATCH 09/21] Fix bugs --- .../spark/sql/execution/SparkStrategies.scala | 17 +-- .../sql/execution/aggregate/AggUtils.scala | 104 ++++++++++-------- .../exchange/EnsureRequirements.scala | 38 ++++--- 3 files changed, 84 insertions(+), 75 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 7f0059f3db28a..cda3b2b75e6b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -259,24 +259,17 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } val aggregateOperator = - if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) { - if (functionsWithDistinct.nonEmpty) { - sys.error("Distinct columns cannot exist in Aggregate operator containing " + - "aggregate functions which don't support partial aggregation.") - } else { - aggregate.AggUtils.planAggregateWithoutDistinct( - groupingExpressions, - aggregateExpressions, - resultExpressions, - planLater(child)) - } - } else if (functionsWithDistinct.isEmpty) { + if (functionsWithDistinct.isEmpty) { aggregate.AggUtils.planAggregateWithoutDistinct( groupingExpressions, aggregateExpressions, resultExpressions, planLater(child)) } else { + if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) { + sys.error("Distinct columns cannot exist in Aggregate operator containing " + + "aggregate functions which don't support partial aggregation.") + } aggregate.AggUtils.planAggregateWithOneDistinct( groupingExpressions, functionsWithDistinct, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index fb808bbbd43f8..fd859627625af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -27,76 +27,87 @@ import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateSto */ object AggUtils { - private[execution] def isAggregateExec(operator: SparkPlan): Boolean = { + private[execution] def isAggregate(operator: SparkPlan): Boolean = { operator.isInstanceOf[HashAggregateExec] || operator.isInstanceOf[SortAggregateExec] } - private def createPartialAggregate( + private[execution] def supportPartialAggregate(operator: SparkPlan): Boolean = { + assert(isAggregate(operator)) + def supportPartial(exprs: Seq[AggregateExpression]) = + exprs.map(_.aggregateFunction).forall(_.supportsPartial) + operator match { + case agg @ HashAggregateExec(_, _, aggregateExpressions, _, _, _, _) => + supportPartial(aggregateExpressions) + case agg @ SortAggregateExec(_, _, aggregateExpressions, _, _, _, _) => + supportPartial(aggregateExpressions) + } + } + + private def createPartialAggregateExec( groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], child: SparkPlan): SparkPlan = { val groupingAttributes = groupingExpressions.map(_.toAttribute) - val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial)) + val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct) + val partialAggregateExpressions = 
aggregateExpressions.map { + case agg @ AggregateExpression(_, _, false, _) if functionsWithDistinct.length > 0 => + agg.copy(mode = PartialMerge) + case agg => + agg.copy(mode = Partial) + } val partialAggregateAttributes = partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) val partialResultExpressions = groupingAttributes ++ partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) - createAggregate( + createAggregateExec( requiredChildDistributionExpressions = None, groupingExpressions = groupingExpressions, aggregateExpressions = partialAggregateExpressions, aggregateAttributes = partialAggregateAttributes, - initialInputBufferOffset = 0, + initialInputBufferOffset = if (functionsWithDistinct.length > 0) { + groupingExpressions.length + functionsWithDistinct.head.aggregateFunction.children.length + } else { + 0 + }, resultExpressions = partialResultExpressions, child = child) } - private def updateAggregateMode(mode: AggregateMode) = mode match { - case Partial => PartialMerge - case Complete => Final - case mode => mode + private def updateMergeAggregateMode(aggregateExpressions: Seq[AggregateExpression]) = { + def updateMode(mode: AggregateMode) = mode match { + case Partial => PartialMerge + case Complete => Final + case mode => mode + } + aggregateExpressions.map(e => e.copy(mode = updateMode(e.mode))) } - private[execution] def addMapSideAggregate(operator: SparkPlan) + private[execution] def createPartialAggregate(operator: SparkPlan) : (SparkPlan, SparkPlan) = operator match { - case agg @ HashAggregateExec( - requiredChildDistributionExpressions, - groupingExpressions, - aggregateExpressions, - aggregateAttributes, - initialInputBufferOffset, - resultExpressions, - child) => - val mapSideAgg = createPartialAggregate(groupingExpressions, aggregateExpressions, child) + case agg @ HashAggregateExec(_, groupingExpressions, aggregateExpressions, _, _, _, child) => + val mapSideAgg = createPartialAggregateExec( + groupingExpressions, aggregateExpressions, child) val mergeAgg = agg.copy( groupingExpressions = groupingExpressions.map(_.toAttribute), - aggregateExpressions = - aggregateExpressions.map(e => e.copy(mode = updateAggregateMode(e.mode))), + aggregateExpressions = updateMergeAggregateMode(aggregateExpressions), initialInputBufferOffset = groupingExpressions.length) (mergeAgg, mapSideAgg) - case agg @ SortAggregateExec( - requiredChildDistributionExpressions, - groupingExpressions, - aggregateExpressions, - aggregateAttributes, - initialInputBufferOffset, - resultExpressions, - child) => - val mapSideAgg = createPartialAggregate(groupingExpressions, aggregateExpressions, child) + case agg @ SortAggregateExec(_, groupingExpressions, aggregateExpressions, _, _, _, child) => + val mapSideAgg = createPartialAggregateExec( + groupingExpressions, aggregateExpressions, child) val mergeAgg = agg.copy( groupingExpressions = groupingExpressions.map(_.toAttribute), - aggregateExpressions = - aggregateExpressions.map(e => e.copy(mode = updateAggregateMode(e.mode))), + aggregateExpressions = updateMergeAggregateMode(aggregateExpressions), initialInputBufferOffset = groupingExpressions.length) (mergeAgg, mapSideAgg) } - private def createAggregate( + private def createAggregateExec( requiredChildDistributionExpressions: Option[Seq[Expression]] = None, groupingExpressions: Seq[NamedExpression] = Nil, aggregateExpressions: Seq[AggregateExpression] = Nil, @@ -105,7 +116,8 @@ object AggUtils { resultExpressions: Seq[NamedExpression] = Nil, 
child: SparkPlan): SparkPlan = { val useHash = HashAggregateExec.supportsAggregate( - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) && + aggregateExpressions.map(_.aggregateFunction).forall(_.supportsPartial) if (useHash) { HashAggregateExec( requiredChildDistributionExpressions = requiredChildDistributionExpressions, @@ -135,9 +147,11 @@ object AggUtils { val groupingAttributes = groupingExpressions.map(_.toAttribute) val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete)) val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute) + val supportPartial = aggregateExpressions.map(_.aggregateFunction).forall(_.supportsPartial) - createAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), + createAggregateExec( + requiredChildDistributionExpressions = + Some(if (supportPartial) groupingAttributes else groupingExpressions), groupingExpressions = groupingExpressions, aggregateExpressions = completeAggregateExpressions, aggregateAttributes = completeAggregateAttributes, @@ -171,7 +185,7 @@ object AggUtils { val partialMergeAggregate: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) - createAggregate( + createAggregateExec( requiredChildDistributionExpressions = Some(groupingAttributes ++ distinctAttributes), groupingExpressions = groupingExpressions ++ namedDistinctExpressions, @@ -183,7 +197,7 @@ object AggUtils { child = child) } - // 2. Create an Aggregate operator for partial aggregation (for distinct) + // 2. Create an Aggregate Operator for the final aggregation. val distinctColumnAttributeLookup = distinctExpressions.zip(distinctAttributes).toMap val rewrittenDistinctFunctions = functionsWithDistinct.map { // Children of an AggregateFunction with DISTINCT keyword has already @@ -193,8 +207,6 @@ object AggUtils { aggregateFunction.transformDown(distinctColumnAttributeLookup) .asInstanceOf[AggregateFunction] } - - // 3. Create an Aggregate Operator for the final aggregation. val finalAndCompleteAggregate: SparkPlan = { val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) // The attributes of the final aggregation buffer, which is presented as input to the result @@ -207,16 +219,16 @@ object AggUtils { // its input will have distinct arguments. // We just keep the isDistinct setting to true, so when users look at the query plan, // they still can see distinct aggregations. 
- val expr = AggregateExpression(func, Final, isDistinct = true) + val expr = AggregateExpression(func, Complete, isDistinct = true) // Use original AggregationFunction to lookup attributes, which is used to build // aggregateFunctionToAttribute val attr = functionsWithDistinct(i).resultAttribute (expr, attr) }.unzip - createAggregate( + createAggregateExec( requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingExpressions, + groupingExpressions = groupingAttributes, aggregateExpressions = finalAggregateExpressions ++ distinctAggregateExpressions, aggregateAttributes = finalAggregateAttributes ++ distinctAggregateAttributes, initialInputBufferOffset = groupingAttributes.length, @@ -249,7 +261,7 @@ object AggUtils { val partialMerged1: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) - createAggregate( + createAggregateExec( requiredChildDistributionExpressions = Some(groupingAttributes), groupingExpressions = groupingExpressions, @@ -266,7 +278,7 @@ object AggUtils { val partialMerged2: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) - createAggregate( + createAggregateExec( requiredChildDistributionExpressions = Some(groupingAttributes), groupingExpressions = groupingAttributes, @@ -288,7 +300,7 @@ object AggUtils { // projection: val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) - createAggregate( + createAggregateExec( requiredChildDistributionExpressions = Some(groupingAttributes), groupingExpressions = groupingAttributes, aggregateExpressions = finalAggregateExpressions, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index e76d01c74c69b..1d056f8ff3276 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -152,35 +152,39 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = { val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering - // var children: Seq[SparkPlan] = operator.children - // assert(requiredChildDistributions.length == children.length) - // assert(requiredChildOrderings.length == children.length) + assert(requiredChildDistributions.length == operator.children.length) + assert(requiredChildOrderings.length == operator.children.length) // Ensure that the operator's children satisfy their output distribution requirements: val childrenWithDist = operator.children.zip(requiredChildDistributions) - var (parent, children) = if (AggUtils.isAggregateExec(operator)) { - // If an aggregation need a shuffle to satisfy its distribution, a map-side partial an - // aggregation and a shuffle are added as children. 
- val (child, distribution) = childrenWithDist.head - if (!child.outputPartitioning.satisfies(distribution)) { - val (mergeAgg, mapSideAgg) = AggUtils.addMapSideAggregate(operator) - val newChild = ShuffleExchange( - createPartitioning(distribution, defaultNumPreShufflePartitions), mapSideAgg) - (mergeAgg, newChild :: Nil) - } else { - (operator, child :: Nil) - } - } else { + def createShuffleExchange(dist: Distribution, child: SparkPlan) = + ShuffleExchange(createPartitioning(dist, defaultNumPreShufflePartitions), child) + + var (parent, children) = if (!AggUtils.isAggregate(operator)) { val newChildren = childrenWithDist.map { case (child, distribution) if child.outputPartitioning.satisfies(distribution) => child case (child, BroadcastDistribution(mode)) => BroadcastExchangeExec(mode, child) case (child, distribution) => - ShuffleExchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child) + createShuffleExchange(distribution, child) } (operator, newChildren) + } else { + val (child, distribution) = childrenWithDist.head + if (!child.outputPartitioning.satisfies(distribution)) { + if (AggUtils.supportPartialAggregate(operator)) { + // If an aggregation needs a shuffle and support partial aggregations, a map-side partial + // an aggregation and a shuffle are added as children. + val (mergeAgg, mapSideAgg) = AggUtils.createPartialAggregate(operator) + (mergeAgg, createShuffleExchange(distribution, mapSideAgg) :: Nil) + } else { + (operator, createShuffleExchange(distribution, child) :: Nil) + } + } else { + (operator, child :: Nil) + } } // If the operator has multiple children and specifies child output distributions (e.g. join), From 7aff3ad141332d81ecd7a06a304e36b5487cf386 Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Sun, 19 Jun 2016 14:15:09 +0900 Subject: [PATCH 10/21] Fix tests --- .../org/apache/spark/sql/execution/PlannerSuite.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 13490c35679a2..24c319f0da78a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -40,10 +40,9 @@ class PlannerSuite extends SharedSQLContext { private def testPartialAggregationPlan(query: LogicalPlan): Unit = { val planner = spark.sessionState.planner import planner._ - val plannedOption = Aggregation(query).headOption - val planned = - plannedOption.getOrElse( - fail(s"Could query play aggregation query $query. Is it an aggregation query?")) + val ensureRequirements = EnsureRequirements(spark.sessionState.conf) + val planned = Aggregation(query).headOption.map(ensureRequirements(_)) + .getOrElse(fail(s"Could query play aggregation query $query. 
Is it an aggregation query?")) val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } // For the new aggregation code path, there will be four aggregate operator for From 000c50120a66cb609551a9743a09dcbe2dffb102 Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Sun, 19 Jun 2016 19:16:15 +0900 Subject: [PATCH 11/21] Update comments --- .../sql/execution/aggregate/AggUtils.scala | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index fd859627625af..9ceeccc2736c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -181,8 +181,8 @@ object AggUtils { val distinctAttributes = namedDistinctExpressions.map(_.toAttribute) val groupingAttributes = groupingExpressions.map(_.toAttribute) - // 1. Create an Aggregate Operator for partial aggregations. - val partialMergeAggregate: SparkPlan = { + // 1. Create an Aggregate Operator for non-distinct aggregations. + val partialAggregate: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) createAggregateExec( @@ -217,8 +217,8 @@ object AggUtils { rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) => // We rewrite the aggregate function to a non-distinct aggregation because // its input will have distinct arguments. - // We just keep the isDistinct setting to true, so when users look at the query plan, - // they still can see distinct aggregations. + // We keep the isDistinct setting to true because this flag is used to generate partial + // aggregations and it is easy to see aggregation types in the query plan. val expr = AggregateExpression(func, Complete, isDistinct = true) // Use original AggregationFunction to lookup attributes, which is used to build // aggregateFunctionToAttribute @@ -233,7 +233,7 @@ object AggUtils { aggregateAttributes = finalAggregateAttributes ++ distinctAggregateAttributes, initialInputBufferOffset = groupingAttributes.length, resultExpressions = resultExpressions, - child = partialMergeAggregate) + child = partialAggregate) } finalAndCompleteAggregate :: Nil @@ -258,7 +258,7 @@ object AggUtils { val groupingAttributes = groupingExpressions.map(_.toAttribute) - val partialMerged1: SparkPlan = { + val partialAggregate: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) createAggregateExec( @@ -273,9 +273,9 @@ object AggUtils { child = child) } - val restored = StateStoreRestoreExec(groupingAttributes, None, partialMerged1) + val restored = StateStoreRestoreExec(groupingAttributes, None, partialAggregate) - val partialMerged2: SparkPlan = { + val partialMerged: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) createAggregateExec( @@ -292,7 +292,7 @@ object AggUtils { // Note: stateId and returnAllStates are filled in later with preparation rules // in IncrementalExecution. 
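    // [Editorial sketch: with the val names used in this method, the streaming
    // aggregation plan assembled here nests as
    //   finalAndCompleteAggregate        (mode = Final, grouped by groupingAttributes)
    //     StateStoreSaveExec             (persists merged buffers across triggers)
    //       partialMerged                (mode = PartialMerge)
    //         StateStoreRestoreExec      (restores buffers from previous triggers)
    //           partialAggregate         (mode = Partial)
    //             child]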
val saved = StateStoreSaveExec( - groupingAttributes, stateId = None, returnAllStates = None, partialMerged2) + groupingAttributes, stateId = None, returnAllStates = None, partialMerged) val finalAndCompleteAggregate: SparkPlan = { val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) From 2615a6e06f75c5b264c726a61cba6e647f77dd6f Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Sun, 19 Jun 2016 22:43:42 +0900 Subject: [PATCH 12/21] Add tests --- .../spark/sql/execution/PlannerSuite.scala | 45 ++++++++++++++----- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 24c319f0da78a..bb9a3bead1558 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{execution, Row} +import org.apache.spark.sql.{execution, DataFrame, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans.Inner @@ -37,35 +37,58 @@ class PlannerSuite extends SharedSQLContext { setupTestData() - private def testPartialAggregationPlan(query: LogicalPlan): Unit = { + private def testPartialAggregationPlan(query: LogicalPlan): Seq[SparkPlan] = { val planner = spark.sessionState.planner import planner._ val ensureRequirements = EnsureRequirements(spark.sessionState.conf) val planned = Aggregation(query).headOption.map(ensureRequirements(_)) .getOrElse(fail(s"Could query play aggregation query $query. Is it an aggregation query?")) - val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } - - // For the new aggregation code path, there will be four aggregate operator for - // distinct aggregations. - assert( - aggregations.size == 2 || aggregations.size == 4, - s"The plan of query $query does not have partial aggregations.") + planned.collect { case n if n.nodeName contains "Aggregate" => n } } test("count is partially aggregated") { val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed - testPartialAggregationPlan(query) + assert(testPartialAggregationPlan(query).size == 2, + s"The plan of query $query does not have partial aggregations.") } test("count distinct is partially aggregated") { val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed testPartialAggregationPlan(query) + // For the new aggregation code path, there will be four aggregate operator for distinct + // aggregations. + assert(testPartialAggregationPlan(query).size == 4, + s"The plan of query $query does not have partial aggregations.") } test("mixed aggregates are partially aggregated") { val query = testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed - testPartialAggregationPlan(query) + // For the new aggregation code path, there will be four aggregate operator for distinct + // aggregations. 
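+    // (Editorial note: two of the four Aggregate nodes come from
+    // AggUtils.planAggregateWithOneDistinct, and EnsureRequirements adds one map-side
+    // partial aggregate in front of each of the two shuffles it inserts, giving four
+    // Aggregate operators in the executed plan.)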
+ assert(testPartialAggregationPlan(query).size == 4, + s"The plan of query $query does not have partial aggregations.") + } + + test("non-partial aggregation for distinct aggregates") { + withTempTable("testNonPartialAggregation") { + val schema = StructType(StructField(s"value", IntegerType, true) :: Nil) + val row = Row.fromSeq(Seq.fill(1)(null)) + val rowRDD = sparkContext.parallelize(row :: Nil) + spark.createDataFrame(rowRDD, schema).createOrReplaceTempView("testNonPartialAggregation") + + val planned = sql( + """ + |SELECT t.value, SUM(DISTINCT t.value) + |FROM (SELECT * FROM testNonPartialAggregation ORDER BY value) t + |GROUP BY t.value + """.stripMargin).queryExecution.executedPlan + + // If input data are already partitioned and the same columns are used in grouping keys and + // aggregation values, no partial aggregation exist in query plans. + val aggOps = planned.collect { case n if n.nodeName contains "Aggregate" => n } + assert(aggOps.size == 2, s"The plan $planned has partial aggregations.") + } } test("sizeInBytes estimation of limit operator for broadcast hash join optimization") { From 91701eb5171555ce73ba6b190b41cc17f05eee8c Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Sat, 20 Aug 2016 01:41:16 +0900 Subject: [PATCH 13/21] Fix a syntax error --- .../scala/org/apache/spark/sql/execution/PlannerSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index bb9a3bead1558..35d70bd2b51a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -71,7 +71,7 @@ class PlannerSuite extends SharedSQLContext { } test("non-partial aggregation for distinct aggregates") { - withTempTable("testNonPartialAggregation") { + withTempView("testNonPartialAggregation") { val schema = StructType(StructField(s"value", IntegerType, true) :: Nil) val row = Row.fromSeq(Seq.fill(1)(null)) val rowRDD = sparkContext.parallelize(row :: Nil) From 74a14d7cb1db8fc0a5dcd18453f7151ba8edb22d Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Mon, 22 Aug 2016 19:31:46 +0900 Subject: [PATCH 14/21] Add tests --- .../aggregate/SortAggregateExec.scala | 3 --- .../exchange/EnsureRequirements.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 21 ++++++++++++------- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index 08988508733d6..2a81a823c44b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -121,6 +121,3 @@ case class SortAggregateExec( } } } - -object SortAggregateExec - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 1d056f8ff3276..509bfb365f42d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -176,7 +176,7 @@ case class EnsureRequirements(conf: SQLConf) extends 
Rule[SparkPlan] { if (!child.outputPartitioning.satisfies(distribution)) { if (AggUtils.supportPartialAggregate(operator)) { // If an aggregation needs a shuffle and support partial aggregations, a map-side partial - // an aggregation and a shuffle are added as children. + // aggregation and a shuffle are added as children. val (mergeAgg, mapSideAgg) = AggUtils.createPartialAggregate(operator) (mergeAgg, createShuffleExchange(distribution, mapSideAgg) :: Nil) } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 35d70bd2b51a8..436ff59c4d3f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -70,24 +70,31 @@ class PlannerSuite extends SharedSQLContext { s"The plan of query $query does not have partial aggregations.") } - test("non-partial aggregation for distinct aggregates") { + test("non-partial aggregation for aggregates") { withTempView("testNonPartialAggregation") { val schema = StructType(StructField(s"value", IntegerType, true) :: Nil) val row = Row.fromSeq(Seq.fill(1)(null)) val rowRDD = sparkContext.parallelize(row :: Nil) - spark.createDataFrame(rowRDD, schema).createOrReplaceTempView("testNonPartialAggregation") + spark.createDataFrame(rowRDD, schema).repartition($"value") + .createOrReplaceTempView("testNonPartialAggregation") - val planned = sql( + val planned1 = sql("SELECT SUM(value) FROM testNonPartialAggregation GROUP BY value") + .queryExecution.executedPlan + + // If input data are already partitioned and the same columns are used in grouping keys and + // aggregation values, no partial aggregation exist in query plans. + val aggOps1 = planned1.collect { case n if n.nodeName contains "Aggregate" => n } + assert(aggOps1.size == 1, s"The plan $planned1 has partial aggregations.") + + val planned2 = sql( """ |SELECT t.value, SUM(DISTINCT t.value) |FROM (SELECT * FROM testNonPartialAggregation ORDER BY value) t |GROUP BY t.value """.stripMargin).queryExecution.executedPlan - // If input data are already partitioned and the same columns are used in grouping keys and - // aggregation values, no partial aggregation exist in query plans. 
-      val aggOps = planned.collect { case n if n.nodeName contains "Aggregate" => n }
-      assert(aggOps.size == 2, s"The plan $planned has partial aggregations.")
+      val aggOps2 = planned2.collect { case n if n.nodeName contains "Aggregate" => n }
+      assert(aggOps2.size == 1, s"The plan $planned2 has partial aggregations.")
     }
   }

From e37ef6afd47e9dd325a7f9e6d0826a3cb66c8e2e Mon Sep 17 00:00:00 2001
From: Takeshi YAMAMURO
Date: Mon, 22 Aug 2016 21:18:07 +0900
Subject: [PATCH 15/21] Add a superclass for *AggregateExec

---
 .../sql/execution/aggregate/AggUtils.scala    | 47 ++++++---------
 .../sql/execution/aggregate/Aggregate.scala   | 58 +++++++++++++++++++
 .../aggregate/HashAggregateExec.scala         | 22 +------
 .../aggregate/SortAggregateExec.scala         | 22 +------
 .../exchange/EnsureRequirements.scala         | 46 ++++++---------
 5 files changed, 96 insertions(+), 99 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index 9ceeccc2736c7..580ce16610ecd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.aggregate
 
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.execution.aggregate.{Aggregate => AggregateExec}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec}
 
@@ -27,20 +28,11 @@ import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateSto
  */
 object AggUtils {
 
-  private[execution] def isAggregate(operator: SparkPlan): Boolean = {
-    operator.isInstanceOf[HashAggregateExec] || operator.isInstanceOf[SortAggregateExec]
-  }
-
-  private[execution] def supportPartialAggregate(operator: SparkPlan): Boolean = {
-    assert(isAggregate(operator))
-    def supportPartial(exprs: Seq[AggregateExpression]) =
-      exprs.map(_.aggregateFunction).forall(_.supportsPartial)
-    operator match {
-      case agg @ HashAggregateExec(_, _, aggregateExpressions, _, _, _, _) =>
-        supportPartial(aggregateExpressions)
-      case agg @ SortAggregateExec(_, _, aggregateExpressions, _, _, _, _) =>
-        supportPartial(aggregateExpressions)
-    }
+  private[execution] def supportPartialAggregate(operator: SparkPlan): Boolean = operator match {
+    case agg: AggregateExec =>
+      agg.aggregateExpressions.map(_.aggregateFunction).forall(_.supportsPartial)
+    case _ =>
+      false
   }
 
   private def createPartialAggregateExec(

   private[execution] def createPartialAggregate(operator: SparkPlan)
     : (SparkPlan, SparkPlan) = operator match {
-    case agg @ HashAggregateExec(_, groupingExpressions, aggregateExpressions, _, _, _, child) =>
-      val mapSideAgg = createPartialAggregateExec(
-        groupingExpressions, aggregateExpressions, child)
-      val mergeAgg = agg.copy(
-        groupingExpressions = groupingExpressions.map(_.toAttribute),
-        aggregateExpressions = updateMergeAggregateMode(aggregateExpressions),
-        initialInputBufferOffset = groupingExpressions.length)
-
-      (mergeAgg, mapSideAgg)
-
-    case agg @ SortAggregateExec(_, groupingExpressions, aggregateExpressions, _, _, _, child) =>
+    case agg: Aggregate =>
       val mapSideAgg = createPartialAggregateExec(
groupingExpressions, aggregateExpressions, child) - val mergeAgg = agg.copy( - groupingExpressions = groupingExpressions.map(_.toAttribute), - aggregateExpressions = updateMergeAggregateMode(aggregateExpressions), - initialInputBufferOffset = groupingExpressions.length) + agg.groupingExpressions, agg.aggregateExpressions, agg.child) + val mergeAgg = createAggregateExec( + requiredChildDistributionExpressions = agg.requiredChildDistributionExpressions, + groupingExpressions = agg.groupingExpressions.map(_.toAttribute), + aggregateExpressions = updateMergeAggregateMode(agg.aggregateExpressions), + aggregateAttributes = agg.aggregateAttributes, + initialInputBufferOffset = agg.groupingExpressions.length, + resultExpressions = agg.resultExpressions, + child = agg.child + ) (mergeAgg, mapSideAgg) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala new file mode 100644 index 0000000000000..657af5403f378 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.SparkPlan + +/** + * A base class for aggregate implementation. 
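+ * (Editorial note: the `requiredChildDistribution` contract implemented below reads as
+ * follows. `requiredChildDistributionExpressions = Some(Seq.empty)` requests `AllTuples`,
+ * i.e. a single partition for global aggregates; `Some(exprs)` requests
+ * `ClusteredDistribution(exprs)`; and `None` leaves the distribution unspecified, as for
+ * map-side partial aggregates.)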
+ */ +trait Aggregate { + self: SparkPlan => + + val requiredChildDistributionExpressions: Option[Seq[Expression]] + val groupingExpressions: Seq[NamedExpression] + val aggregateExpressions: Seq[AggregateExpression] + val aggregateAttributes: Seq[Attribute] + val initialInputBufferOffset: Int + val resultExpressions: Seq[NamedExpression] + val child: SparkPlan + + protected[this] val aggregateBufferAttributes = { + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + } + + override def producedAttributes: AttributeSet = + AttributeSet(aggregateAttributes) ++ + AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ + AttributeSet(aggregateBufferAttributes) + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + + override def requiredChildDistribution: List[Distribution] = { + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.isEmpty => AllTuples :: Nil + case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil + case None => UnspecifiedDistribution :: Nil + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index cfc47aba889aa..30251cf1a051a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.types.{DecimalType, StringType, StructType} @@ -42,11 +41,7 @@ case class HashAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode with CodegenSupport { - - private[this] val aggregateBufferAttributes = { - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) - } + extends UnaryExecNode with Aggregate with CodegenSupport { require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes)) @@ -60,21 +55,6 @@ case class HashAggregateExec( "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"), "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time")) - override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) - - override def producedAttributes: AttributeSet = - AttributeSet(aggregateAttributes) ++ - AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ - AttributeSet(aggregateBufferAttributes) - - override def requiredChildDistribution: List[Distribution] = { - requiredChildDistributionExpressions match { - case Some(exprs) if exprs.isEmpty => AllTuples :: Nil - case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil - case None => UnspecifiedDistribution :: Nil - } - } - // This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash // map and/or the sort-based aggregation once it has processed a given number of input rows. 
private val testFallbackStartsAt: Option[(Int, Int)] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index 2a81a823c44b3..047fe1e5c711b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.Utils @@ -38,30 +37,11 @@ case class SortAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode { - - private[this] val aggregateBufferAttributes = { - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) - } - - override def producedAttributes: AttributeSet = - AttributeSet(aggregateAttributes) ++ - AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ - AttributeSet(aggregateBufferAttributes) + extends UnaryExecNode with Aggregate { override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) - override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) - - override def requiredChildDistribution: List[Distribution] = { - requiredChildDistributionExpressions match { - case Some(exprs) if exprs.isEmpty => AllTuples :: Nil - case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil - case None => UnspecifiedDistribution :: Nil - } - } - override def requiredChildOrdering: Seq[Seq[SortOrder]] = { groupingExpressions.map(SortOrder(_, Ascending)) :: Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 509bfb365f42d..a4b3cb463b219 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.aggregate.AggUtils +import org.apache.spark.sql.execution.aggregate.{Aggregate, AggUtils} import org.apache.spark.sql.internal.SQLConf /** @@ -155,36 +155,28 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { assert(requiredChildDistributions.length == operator.children.length) assert(requiredChildOrderings.length == operator.children.length) - // Ensure that the operator's children satisfy their output distribution requirements: - val childrenWithDist = operator.children.zip(requiredChildDistributions) - def createShuffleExchange(dist: Distribution, child: SparkPlan) = ShuffleExchange(createPartitioning(dist, defaultNumPreShufflePartitions), child) - var (parent, children) = if 
(!AggUtils.isAggregate(operator)) { - val newChildren = childrenWithDist.map { - case (child, distribution) if child.outputPartitioning.satisfies(distribution) => - child - case (child, BroadcastDistribution(mode)) => - BroadcastExchangeExec(mode, child) - case (child, distribution) => - createShuffleExchange(distribution, child) - } - (operator, newChildren) - } else { - val (child, distribution) = childrenWithDist.head - if (!child.outputPartitioning.satisfies(distribution)) { - if (AggUtils.supportPartialAggregate(operator)) { - // If an aggregation needs a shuffle and support partial aggregations, a map-side partial - // aggregation and a shuffle are added as children. - val (mergeAgg, mapSideAgg) = AggUtils.createPartialAggregate(operator) - (mergeAgg, createShuffleExchange(distribution, mapSideAgg) :: Nil) - } else { - (operator, createShuffleExchange(distribution, child) :: Nil) + var (parent, children) = operator match { + case agg if AggUtils.supportPartialAggregate(agg) && + !operator.outputPartitioning.satisfies(requiredChildDistributions.head) => + // If an aggregation needs a shuffle and support partial aggregations, a map-side partial + // aggregation and a shuffle are added as children. + val (mergeAgg, mapSideAgg) = AggUtils.createPartialAggregate(operator) + (mergeAgg, createShuffleExchange(requiredChildDistributions.head, mapSideAgg) :: Nil) + case _ => + // Ensure that the operator's children satisfy their output distribution requirements: + val childrenWithDist = operator.children.zip(requiredChildDistributions) + val newChildren = childrenWithDist.map { + case (child, distribution) if child.outputPartitioning.satisfies(distribution) => + child + case (child, BroadcastDistribution(mode)) => + BroadcastExchangeExec(mode, child) + case (child, distribution) => + createShuffleExchange(distribution, child) } - } else { - (operator, child :: Nil) - } + (operator, newChildren) } // If the operator has multiple children and specifies child output distributions (e.g. join), From 8b643055e47d6213847d0cf10adf83a3f59ece55 Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Mon, 22 Aug 2016 21:48:43 +0900 Subject: [PATCH 16/21] Replace isSupportPartial with an extractor --- .../sql/execution/aggregate/AggUtils.scala | 20 +++++++++++++------ .../exchange/EnsureRequirements.scala | 7 ++++--- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 580ce16610ecd..e286561de9fb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -19,21 +19,29 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.execution.aggregate.{Aggregate => AggregateExec} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec} /** - * Utility functions used by the query planner to convert our plan to new aggregation code path. + * A pattern that finds aggregate operators to support partial aggregations. 
*/ -object AggUtils { +object ExtractPartialAggregate { - private[execution] def supportPartialAggregate(operator: SparkPlan): Boolean = operator match { - case agg: AggregateExec => - agg.aggregateExpressions.map(_.aggregateFunction).forall(_.supportsPartial) + def unapply(plan: SparkPlan): Option[Distribution] = plan match { + case agg: AggregateExec + if agg.aggregateExpressions.map(_.aggregateFunction).forall(_.supportsPartial) => + Some(agg.requiredChildDistribution.head) case _ => - false + None } +} + +/** + * Utility functions used by the query planner to convert our plan to new aggregation code path. + */ +object AggUtils { private def createPartialAggregateExec( groupingExpressions: Seq[NamedExpression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index a4b3cb463b219..7afae71a8a0fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -21,7 +21,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.aggregate.{Aggregate, AggUtils} +import org.apache.spark.sql.execution.aggregate.AggUtils +import org.apache.spark.sql.execution.aggregate.ExtractPartialAggregate import org.apache.spark.sql.internal.SQLConf /** @@ -159,8 +160,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { ShuffleExchange(createPartitioning(dist, defaultNumPreShufflePartitions), child) var (parent, children) = operator match { - case agg if AggUtils.supportPartialAggregate(agg) && - !operator.outputPartitioning.satisfies(requiredChildDistributions.head) => + case ExtractPartialAggregate(childDist) + if !operator.outputPartitioning.satisfies(childDist) => // If an aggregation needs a shuffle and support partial aggregations, a map-side partial // aggregation and a shuffle are added as children. val (mergeAgg, mapSideAgg) = AggUtils.createPartialAggregate(operator) From 0375ac69a517092a6ac6bb412b6ffb1509835c8a Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Tue, 23 Aug 2016 00:09:15 +0900 Subject: [PATCH 17/21] Rename the extractor --- .../org/apache/spark/sql/execution/aggregate/AggUtils.scala | 2 +- .../spark/sql/execution/exchange/EnsureRequirements.scala | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index e286561de9fb6..b5d6708853e29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateSto /** * A pattern that finds aggregate operators to support partial aggregations. 
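  * (Editorial sketch of intended usage, with a hypothetical rule body:
  *   plan match {
  *     case PartialAggregate(requiredChildDist) => // may be split into map/merge sides
  *     case _ => // leave the operator unchanged
  *   }
  * The extracted value is the distribution the aggregate requires of its child.)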
 */
-object ExtractPartialAggregate {
+object PartialAggregate {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 7afae71a8a0fd..c5337bac5bb1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.aggregate.AggUtils
-import org.apache.spark.sql.execution.aggregate.ExtractPartialAggregate
+import org.apache.spark.sql.execution.aggregate.PartialAggregate
 import org.apache.spark.sql.internal.SQLConf
 
 /**
@@ -160,8 +160,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
       ShuffleExchange(createPartitioning(dist, defaultNumPreShufflePartitions), child)
 
     var (parent, children) = operator match {
-      case ExtractPartialAggregate(childDist)
-          if !operator.outputPartitioning.satisfies(childDist) =>
+      case PartialAggregate(childDist) if !operator.outputPartitioning.satisfies(childDist) =>
         // If an aggregation needs a shuffle and support partial aggregations, a map-side partial
         // aggregation and a shuffle are added as children.
         val (mergeAgg, mapSideAgg) = AggUtils.createPartialAggregate(operator)

From 86068d0f9db2cd1be91e5ec0c56d6c7c074438c8 Mon Sep 17 00:00:00 2001
From: Takeshi YAMAMURO
Date: Tue, 23 Aug 2016 01:49:19 +0900
Subject: [PATCH 18/21] Replace trait with superclass

---
 .../sql/execution/aggregate/AggUtils.scala    |  3 +--
 .../{Aggregate.scala => AggregateExec.scala}  | 18 ++++++++----------
 .../aggregate/HashAggregateExec.scala         |  2 +-
 .../aggregate/SortAggregateExec.scala         |  4 ++--
 4 files changed, 12 insertions(+), 15 deletions(-)
 rename sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/{Aggregate.scala => AggregateExec.scala} (82%)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index b5d6708853e29..195d134260cbc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.physical.Distribution
-import org.apache.spark.sql.execution.aggregate.{Aggregate => AggregateExec}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec}
 
@@ -86,7 +85,7 @@ object AggUtils {
 
   private[execution] def createPartialAggregate(operator: SparkPlan)
     : (SparkPlan, SparkPlan) = operator match {
-    case agg: Aggregate =>
+    case agg: AggregateExec =>
       val mapSideAgg = createPartialAggregateExec(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala
similarity index 82% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala index 657af5403f378..b88a8aa3daecd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala @@ -21,20 +21,19 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.UnaryExecNode /** * A base class for aggregate implementation. */ -trait Aggregate { - self: SparkPlan => +abstract class AggregateExec extends UnaryExecNode { - val requiredChildDistributionExpressions: Option[Seq[Expression]] - val groupingExpressions: Seq[NamedExpression] - val aggregateExpressions: Seq[AggregateExpression] - val aggregateAttributes: Seq[Attribute] - val initialInputBufferOffset: Int - val resultExpressions: Seq[NamedExpression] - val child: SparkPlan + def requiredChildDistributionExpressions: Option[Seq[Expression]] + def groupingExpressions: Seq[NamedExpression] + def aggregateExpressions: Seq[AggregateExpression] + def aggregateAttributes: Seq[Attribute] + def initialInputBufferOffset: Int + def resultExpressions: Seq[NamedExpression] protected[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) @@ -47,7 +46,6 @@ trait Aggregate { override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) - override def requiredChildDistribution: List[Distribution] = { requiredChildDistributionExpressions match { case Some(exprs) if exprs.isEmpty => AllTuples :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 30251cf1a051a..fa977e5ba5bdc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -41,7 +41,7 @@ case class HashAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode with Aggregate with CodegenSupport { + extends AggregateExec with CodegenSupport { require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index 047fe1e5c711b..68f86fca80937 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.Utils @@ -37,7 +37,7 @@ case class SortAggregateExec( 
initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode with Aggregate { + extends AggregateExec { override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) From 69751bb5d22f9f8a5505ddbaa035f664309ce30b Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Tue, 23 Aug 2016 17:12:24 +0900 Subject: [PATCH 19/21] Apply minor comments --- .../apache/spark/sql/execution/aggregate/AggUtils.scala | 9 +++++++-- .../sql/execution/exchange/EnsureRequirements.scala | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 195d134260cbc..244ae5b992609 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -83,8 +83,13 @@ object AggUtils { aggregateExpressions.map(e => e.copy(mode = updateMode(e.mode))) } - private[execution] def createPartialAggregate(operator: SparkPlan) - : (SparkPlan, SparkPlan) = operator match { + /** + * Builds new merge and map-side [[AggregateExec]]s from an input aggregate operator. + * If an aggregation needs a shuffle for satisfying its own distribution and supports partial + * aggregations, a map-side aggregation is appended before the shuffle in + * [[org.apache.spark.sql.execution.exchange.EnsureRequirements]]. + */ + def createMapMergeAggregatePair(operator: SparkPlan): (SparkPlan, SparkPlan) = operator match { case agg: AggregateExec => val mapSideAgg = createPartialAggregateExec( agg.groupingExpressions, agg.aggregateExpressions, agg.child) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index c5337bac5bb1e..951051c4df2f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -163,7 +163,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { case PartialAggregate(childDist) if !operator.outputPartitioning.satisfies(childDist) => // If an aggregation needs a shuffle and support partial aggregations, a map-side partial // aggregation and a shuffle are added as children. 
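        // [Editorial sketch, using the names bound in this branch: the rewrite turns
        //   operator(child)
        // into
        //   mergeAgg(createShuffleExchange(requiredChildDistributions.head, mapSideAgg))
        // so the map-side partial aggregate runs before the shuffle and the merge
        // aggregate runs after it.]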
- val (mergeAgg, mapSideAgg) = AggUtils.createPartialAggregate(operator) + val (mergeAgg, mapSideAgg) = AggUtils.createMapMergeAggregatePair(operator) (mergeAgg, createShuffleExchange(requiredChildDistributions.head, mapSideAgg) :: Nil) case _ => // Ensure that the operator's children satisfy their output distribution requirements: From d5e0ed3d0efdc8047948e48cdc6fb1257cc381f0 Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Wed, 24 Aug 2016 10:09:46 +0900 Subject: [PATCH 20/21] Add a new function to check if it supports partial agg --- .../spark/sql/execution/aggregate/AggUtils.scala | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 244ae5b992609..2cf5b005cf70a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -29,8 +29,7 @@ import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateSto object PartialAggregate { def unapply(plan: SparkPlan): Option[Distribution] = plan match { - case agg: AggregateExec - if agg.aggregateExpressions.map(_.aggregateFunction).forall(_.supportsPartial) => + case agg: AggregateExec if AggUtils.supportPartialAggregate(agg.aggregateExpressions) => Some(agg.requiredChildDistribution.head) case _ => None @@ -42,6 +41,10 @@ object PartialAggregate { */ object AggUtils { + def supportPartialAggregate(aggregateExpressions: Seq[AggregateExpression]): Boolean = { + aggregateExpressions.map(_.aggregateFunction).forall(_.supportsPartial) + } + private def createPartialAggregateExec( groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], @@ -116,7 +119,7 @@ object AggUtils { child: SparkPlan): SparkPlan = { val useHash = HashAggregateExec.supportsAggregate( aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) && - aggregateExpressions.map(_.aggregateFunction).forall(_.supportsPartial) + supportPartialAggregate(aggregateExpressions) if (useHash) { HashAggregateExec( requiredChildDistributionExpressions = requiredChildDistributionExpressions, @@ -146,7 +149,7 @@ object AggUtils { val groupingAttributes = groupingExpressions.map(_.toAttribute) val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete)) val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute) - val supportPartial = aggregateExpressions.map(_.aggregateFunction).forall(_.supportsPartial) + val supportPartial = supportPartialAggregate(aggregateExpressions) createAggregateExec( requiredChildDistributionExpressions = From ac6814514e091ffc377fde75d5c79b583c878626 Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Thu, 25 Aug 2016 10:38:53 +0900 Subject: [PATCH 21/21] Apply comments --- .../org/apache/spark/sql/execution/aggregate/AggUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 2cf5b005cf70a..fe75ecea177a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -103,7 +103,7 @@ object AggUtils { aggregateAttributes = agg.aggregateAttributes, 
initialInputBufferOffset = agg.groupingExpressions.length, resultExpressions = agg.resultExpressions, - child = agg.child + child = mapSideAgg ) (mergeAgg, mapSideAgg)