diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index c8ed4190a13ad..ecf895dc6e240 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -17,20 +17,14 @@
 package org.apache.spark.sql.catalyst.optimizer
 
-import scala.annotation.tailrec
-import scala.collection.immutable.HashSet
 import scala.collection.mutable
-import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.api.java.function.FilterFunction
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
-import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 8e8210e334a1d..0fa54c061b7d0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.execution.aggregate.MergePartialAggregate
 import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand}
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
 import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampType, _}
@@ -105,6 +106,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
     python.ExtractPythonUDFs,
     PlanSubqueries(sparkSession),
     EnsureRequirements(sparkSession.sessionState.conf),
+    MergePartialAggregate,
     CollapseCodegenStages(sparkSession.sessionState.conf),
     ReuseExchange(sparkSession.sessionState.conf),
     ReuseSubquery(sparkSession.sessionState.conf))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala
new file mode 100644
index 0000000000000..1b0d02122cb30
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.UnaryExecNode
+
+/**
+ * A base class for aggregate implementations.
+ */
+abstract class AggregateExec extends UnaryExecNode {
+
+  def requiredChildDistributionExpressions: Option[Seq[Expression]]
+  def groupingExpressions: Seq[NamedExpression]
+  def aggregateExpressions: Seq[AggregateExpression]
+  def aggregateAttributes: Seq[Attribute]
+  def initialInputBufferOffset: Int
+  def resultExpressions: Seq[NamedExpression]
+  def child: SparkPlan
+
+  protected[this] val aggregateBufferAttributes = {
+    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+  }
+
+  override def producedAttributes: AttributeSet =
+    AttributeSet(aggregateAttributes) ++
+    AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
+    AttributeSet(aggregateBufferAttributes)
+
+  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+  // Keep the child's partitioning here; the equivalent override is removed from
+  // HashAggregateExec below and must not silently fall back to UnknownPartitioning.
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  override def requiredChildDistribution: List[Distribution] = {
+    requiredChildDistributionExpressions match {
+      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
+      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
+      case None => UnspecifiedDistribution :: Nil
+    }
+  }
+
+  def copy(
+      requiredChildDistributionExpressions: Option[Seq[Expression]] =
+        requiredChildDistributionExpressions,
+      groupingExpressions: Seq[NamedExpression] = groupingExpressions,
+      aggregateExpressions: Seq[AggregateExpression] = aggregateExpressions,
+      aggregateAttributes: Seq[Attribute] = aggregateAttributes,
+      initialInputBufferOffset: Int = initialInputBufferOffset,
+      resultExpressions: Seq[NamedExpression] = resultExpressions,
+      child: SparkPlan = child): AggregateExec
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 68c8e6ce62cbb..32d3f89a2d8a2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -43,11 +43,7 @@ case class HashAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode with CodegenSupport {
-
-  private[this] val aggregateBufferAttributes = {
-    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
-  }
+  extends AggregateExec with CodegenSupport {
 
   require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))
 
@@ -61,23 +57,8 @@ case class HashAggregateExec(
     "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
     "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time"))
 
-  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-  override def outputPartitioning: Partitioning = child.outputPartitioning
-
-  override def producedAttributes: AttributeSet =
-    AttributeSet(aggregateAttributes) ++
-    AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
-    AttributeSet(aggregateBufferAttributes)
-
-  override def requiredChildDistribution: List[Distribution] = {
-    requiredChildDistributionExpressions match {
-      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
-      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
-      case None => UnspecifiedDistribution :: Nil
-    }
-  }
-
   // This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash
   // map and/or the sort-based aggregation once it has processed a given number of input rows.
   private val testFallbackStartsAt: Option[(Int, Int)] = {
@@ -873,6 +854,25 @@ case class HashAggregateExec(
     """
   }
 
+  def copy(
+      requiredChildDistributionExpressions: Option[Seq[Expression]],
+      groupingExpressions: Seq[NamedExpression],
+      aggregateExpressions: Seq[AggregateExpression],
+      aggregateAttributes: Seq[Attribute],
+      initialInputBufferOffset: Int,
+      resultExpressions: Seq[NamedExpression],
+      child: SparkPlan): AggregateExec = {
+    new HashAggregateExec(
+      requiredChildDistributionExpressions,
+      groupingExpressions,
+      aggregateExpressions,
+      aggregateAttributes,
+      initialInputBufferOffset,
+      resultExpressions,
+      child
+    )
+  }
+
   override def verboseString: String = toString(verbose = true)
 
   override def simpleString: String = toString(verbose = false)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergePartialAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergePartialAggregate.scala
new file mode 100644
index 0000000000000..e3803d476a2e0
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergePartialAggregate.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
+
+
+/**
+ * A pattern that finds an aggregate pair: a partial aggregate and its parent.
+ */
+object AggregatePair {
+
+  def unapply(plan: SparkPlan): Option[(AggregateExec, AggregateExec)] = plan match {
+    case outer: AggregateExec if outer.child.isInstanceOf[AggregateExec] =>
+      val inner = outer.child.asInstanceOf[AggregateExec]
+
+      // Check that the classes and grouping keys of the two nodes match, to make sure the
+      // two aggregates implement the two halves of the same GROUP BY clause.
+      if (outer.getClass == inner.getClass &&
+          outer.groupingExpressions.map(_.toAttribute) ==
+            inner.groupingExpressions.map(_.toAttribute)) {
+        Some((outer, inner))
+      } else {
+        None
+      }
+
+    case _ =>
+      None
+  }
+}
+
+/**
+ * Merge partial (map-side) aggregates into their parent aggregates if the parent aggregates
+ * directly have the partial aggregates as children.
+ */
+object MergePartialAggregate extends Rule[SparkPlan] {
+
+  override def apply(plan: SparkPlan): SparkPlan = plan transform {
+    // Normal pair: a partial aggregate and its parent
+    case AggregatePair(outer, inner)
+        if outer.aggregateExpressions.forall(_.mode == Final) &&
+          inner.aggregateExpressions.forall(expr => expr.mode == Partial && !expr.isDistinct) =>
+      inner.copy(
+        aggregateExpressions = inner.aggregateExpressions.map(_.copy(mode = Complete)),
+        aggregateAttributes = inner.aggregateExpressions.map(_.resultAttribute),
+        resultExpressions = outer.resultExpressions)
+
+    // First partial aggregate pair for aggregation with distinct
+    case AggregatePair(outer, inner)
+        if (outer.aggregateExpressions.forall(_.mode == PartialMerge) &&
+          inner.aggregateExpressions.forall(_.mode == Partial)) ||
+          // If a query has a single distinct aggregate (that is, it has no non-distinct
+          // aggregate), the first aggregate pair has empty aggregate functions.
+          Seq(outer, inner).forall(_.aggregateExpressions.isEmpty) =>
+      inner
+
+    // Second partial aggregate pair for aggregation with distinct. If the input data are already
+    // partitioned and the same columns are used in grouping keys and aggregation values, the
+    // distribution requirement of the second partial aggregate is satisfied, so we can safely
+    // merge the second aggregate into its parent.
+    // An example query for this case is:
+    //
+    //   SELECT t.value, SUM(DISTINCT t.value)
+    //   FROM (SELECT * FROM inputTable ORDER BY value) t
+    //   GROUP BY t.value
+    case AggregatePair(outer, inner)
+        if outer.aggregateExpressions.forall(_.mode == Final) &&
+          inner.aggregateExpressions.exists(_.isDistinct) =>
+      outer.copy(
+        aggregateExpressions = outer.aggregateExpressions.map {
+          case funcWithDistinct if funcWithDistinct.isDistinct =>
+            funcWithDistinct.copy(mode = Complete)
+          case otherFunc =>
+            otherFunc
+        },
+        child = inner.child
+      )
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
index be3198b8e7d82..efd91dd6bcdca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
@@ -38,30 +38,11 @@ case class SortAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode {
-
-  private[this] val aggregateBufferAttributes = {
-    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
-  }
-
-  override def producedAttributes: AttributeSet =
-    AttributeSet(aggregateAttributes) ++
-    AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
-    AttributeSet(aggregateBufferAttributes)
+  extends AggregateExec {
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
-  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
-  override def requiredChildDistribution: List[Distribution] = {
-    requiredChildDistributionExpressions match {
-      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
-      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
-      case None => UnspecifiedDistribution :: Nil
-    }
-  }
-
   override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
     groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
   }
@@ -106,6 +87,25 @@ case class SortAggregateExec(
     }
   }
 
+  def copy(
+      requiredChildDistributionExpressions: Option[Seq[Expression]],
+      groupingExpressions: Seq[NamedExpression],
+      aggregateExpressions: Seq[AggregateExpression],
+      aggregateAttributes: Seq[Attribute],
+      initialInputBufferOffset: Int,
+      resultExpressions: Seq[NamedExpression],
+      child: SparkPlan): AggregateExec = {
+    new SortAggregateExec(
+      requiredChildDistributionExpressions,
+      groupingExpressions,
+      aggregateExpressions,
+      aggregateAttributes,
+      initialInputBufferOffset,
+      resultExpressions,
+      child
+    )
+  }
+
   override def simpleString: String = toString(verbose = false)
 
   override def verboseString: String = toString(verbose = true)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 610ce5e1ebf5d..dcab180307adc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.streaming
 import java.util.concurrent.atomic.AtomicInteger
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.expressions.{CurrentBatchTimestamp, Literal}
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.CurrentBatchTimestamp
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 52bd4e19f8952..fa63d27e51136 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1236,17 +1236,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
   }
 
   /**
-   * Verifies that there is no Exchange between the Aggregations for `df`
+   * Verifies that there is a single Aggregate operator for `df`
    */
-  private def verifyNonExchangingAgg(df: DataFrame) = {
+  private def verifyNonExchangingSingleAgg(df: DataFrame) = {
     var atFirstAgg: Boolean = false
     df.queryExecution.executedPlan.foreach {
       case agg: HashAggregateExec =>
-        atFirstAgg = !atFirstAgg
-      case _ =>
         if (atFirstAgg) {
-          fail("Should not have operators between the two aggregations")
+          fail("Should not have back-to-back Aggregates")
         }
+        atFirstAgg = true
+      case _ =>
     }
   }
 
@@ -1280,9 +1280,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     // Group by the column we are distributed by. This should generate a plan with no exchange
     // between the aggregates
     val df3 = testData.repartition($"key").groupBy("key").count()
-    verifyNonExchangingAgg(df3)
-    verifyNonExchangingAgg(testData.repartition($"key", $"value")
+    verifyNonExchangingSingleAgg(df3)
+    verifyNonExchangingSingleAgg(testData.repartition($"key", $"value")
       .groupBy("key", "value").count())
+    verifyNonExchangingSingleAgg(testData.repartition($"key").groupBy("key", "value").count())
 
     // Grouping by just the first distributeBy expr, need to exchange.
     verifyExchangingAgg(testData.repartition($"key", $"value")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 02ccebd22bdf9..89cfc7e23fabe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -37,36 +37,65 @@ class PlannerSuite extends SharedSQLContext {
 
   setupTestData()
 
-  private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
+  private def testPartialAggregationPlan(query: LogicalPlan): Seq[SparkPlan] = {
     val planner = spark.sessionState.planner
     import planner._
-    val plannedOption = Aggregation(query).headOption
-    val planned =
-      plannedOption.getOrElse(
-        fail(s"Could query play aggregation query $query. Is it an aggregation query?"))
-    val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
-
-    // For the new aggregation code path, there will be four aggregate operator for
-    // distinct aggregations.
-    assert(
-      aggregations.size == 2 || aggregations.size == 4,
-      s"The plan of query $query does not have partial aggregations.")
+    val ensureRequirements = EnsureRequirements(spark.sessionState.conf)
+    val planned = Aggregation(query).headOption.map(ensureRequirements(_))
+      .getOrElse(fail(s"Could not plan aggregation query $query. Is it an aggregation query?"))
+    planned.collect { case n if n.nodeName contains "Aggregate" => n }
   }
 
   test("count is partially aggregated") {
     val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
-    testPartialAggregationPlan(query)
+    assert(testPartialAggregationPlan(query).size == 2,
+      s"The plan of query $query does not have partial aggregations.")
   }
 
   test("count distinct is partially aggregated") {
    val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
-    testPartialAggregationPlan(query)
+    // For the new aggregation code path, there will be four aggregate operators for distinct
+    // aggregations.
+    assert(testPartialAggregationPlan(query).size == 4,
+      s"The plan of query $query does not have partial aggregations.")
   }
 
   test("mixed aggregates are partially aggregated") {
     val query =
       testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
-    testPartialAggregationPlan(query)
+    // For the new aggregation code path, there will be four aggregate operators for distinct
+    // aggregations.
+    assert(testPartialAggregationPlan(query).size == 4,
+      s"The plan of query $query does not have partial aggregations.")
+  }
+
+  test("non-partial aggregation for aggregates") {
+    withTempView("testNonPartialAggregation") {
+      val schema = StructType(StructField("value", IntegerType, true) :: Nil)
+      val row = Row.fromSeq(Seq.fill(1)(null))
+      val rowRDD = sparkContext.parallelize(row :: Nil)
+      spark.createDataFrame(rowRDD, schema).repartition($"value")
+        .createOrReplaceTempView("testNonPartialAggregation")
+
+      val planned1 = sql("SELECT SUM(value) FROM testNonPartialAggregation GROUP BY value")
+        .queryExecution.executedPlan
+
+      // If the input data are already partitioned and the same columns are used in grouping keys
+      // and aggregation values, no partial aggregation should exist in the query plan.
+      val aggOps1 = planned1.collect { case n if n.nodeName contains "Aggregate" => n }
+      assert(aggOps1.size == 1, s"The plan $planned1 has partial aggregations.")
+
+      val planned2 = sql(
+        """
+          |SELECT t.value, SUM(DISTINCT t.value)
+          |FROM (SELECT * FROM testNonPartialAggregation ORDER BY value) t
+          |GROUP BY t.value
+        """.stripMargin).queryExecution.executedPlan
+
+      val aggOps2 = planned2.collect { case n if n.nodeName contains "Aggregate" => n }
+      assert(aggOps2.size == 2, s"The plan $planned2 has partial aggregations.")
+    }
+  }
 
   test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
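Reviewer note (not part of the patch): the core rewrite that `MergePartialAggregate` performs can be hard to see through the plumbing above, so here is a minimal, self-contained Scala sketch of the idea. Every name in it (`Mode`, `Agg`, `Exchange`, `merge`) is an invented stand-in for illustration, not Spark API; the real rule pattern-matches `AggregateExec` pairs via `AggregatePair`, rewrites them with the `copy` method added above, and additionally handles the two distinct-aggregation pairs.

```scala
// Toy model of the MergePartialAggregate rewrite. All names here are
// invented stand-ins for this sketch; none of them are Spark classes.
object MergePartialAggregateSketch extends App {

  sealed trait Mode
  case object Partial extends Mode   // map-side half of a split aggregate
  case object Final extends Mode     // reduce-side half of a split aggregate
  case object Complete extends Mode  // a single, unsplit aggregate

  sealed trait Plan
  case class Scan(table: String) extends Plan
  case class Exchange(child: Plan) extends Plan // a shuffle boundary
  case class Agg(mode: Mode, keys: Seq[String], child: Plan) extends Plan

  // If a Final aggregate sits directly on top of its Partial half with the
  // same grouping keys (i.e. EnsureRequirements inserted no Exchange between
  // them because the input was already partitioned on those keys), collapse
  // the pair into one Complete aggregate.
  def merge(plan: Plan): Plan = plan match {
    case Agg(Final, keys, Agg(Partial, innerKeys, child)) if keys == innerKeys =>
      Agg(Complete, keys, merge(child))
    case Agg(mode, keys, child) => Agg(mode, keys, merge(child))
    case Exchange(child) => Exchange(merge(child))
    case scan: Scan => scan
  }

  // Adjacent halves (input pre-partitioned, no shuffle in between): merged.
  val adjacent = Agg(Final, Seq("key"), Agg(Partial, Seq("key"), Scan("t")))
  assert(merge(adjacent) == Agg(Complete, Seq("key"), Scan("t")))

  // The common case, with a shuffle between the two halves: left untouched.
  val shuffled = Agg(Final, Seq("key"), Exchange(Agg(Partial, Seq("key"), Scan("t"))))
  assert(merge(shuffled) == shuffled)
}
```

This also mirrors why the rule is registered after `EnsureRequirements` in `QueryExecution.preparations`: only once exchanges have been placed can the planner tell which partial/final pairs ended up adjacent and are therefore safe to merge.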