From a5b6d7140cf0de3bab75788fd32b497a8817c93c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 30 Aug 2016 22:44:49 +0800 Subject: [PATCH] remove unnecessary partial aggregate --- .../spark/sql/execution/QueryExecution.scala | 4 +- .../aggregate/MergePartialAggregate.scala | 55 +++++++++++++++++++ .../spark/sql/execution/PlannerSuite.scala | 24 +++++++- 3 files changed, 79 insertions(+), 4 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergePartialAggregate.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index d4845637be049..6593971171562 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -21,12 +21,13 @@ import java.nio.charset.StandardCharsets import java.sql.Timestamp import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{AnalysisException, Row, SparkSession, SQLContext} +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.aggregate.MergePartialAggregate import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec} import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} import org.apache.spark.sql.internal.SQLConf @@ -100,6 +101,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { python.ExtractPythonUDFs, PlanSubqueries(sparkSession), EnsureRequirements(sparkSession.sessionState.conf), + MergePartialAggregate, CollapseCodegenStages(sparkSession.sessionState.conf), ReuseExchange(sparkSession.sessionState.conf), ReuseSubquery(sparkSession.sessionState.conf)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergePartialAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergePartialAggregate.scala new file mode 100644 index 0000000000000..c998b8ff71b81 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergePartialAggregate.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
+
+object MergePartialAggregate extends Rule[SparkPlan] {
+
+  override def apply(plan: SparkPlan): SparkPlan = plan transform {
+    // Normal partial/final aggregate pair: merge it into one Complete-mode aggregate.
+    case outer @ HashAggregateExec(_, _, _, _, _, _, inner: HashAggregateExec)
+      if outer.aggregateExpressions.forall(_.mode == Final) &&
+        inner.aggregateExpressions.forall(_.mode == Partial) =>
+      inner.copy(
+        aggregateExpressions = inner.aggregateExpressions.map(_.copy(mode = Complete)),
+        aggregateAttributes = inner.aggregateExpressions.map(_.resultAttribute),
+        resultExpressions = outer.resultExpressions)
+
+    // First aggregate pair for aggregation with distinct: drop the redundant PartialMerge step.
+    case outer @ HashAggregateExec(_, _, _, _, _, _, inner: HashAggregateExec)
+      if outer.aggregateExpressions.forall(_.mode == PartialMerge) &&
+        inner.aggregateExpressions.forall(_.mode == Partial) =>
+      inner
+
+    // Second partial aggregate pair for aggregation with distinct.
+    // This case is actually a no-op: the output of the first aggregate pair is partitioned by
+    // the grouping expressions plus the distinct attributes, while the second pair requires its
+    // input to be partitioned by the grouping attributes alone. `EnsureRequirements` will
+    // therefore always insert an exchange between these two aggregate operators, so we never
+    // hit this branch.
+    case outer @ HashAggregateExec(_, _, _, _, _, _, inner: HashAggregateExec)
+      if outer.aggregateExpressions.forall(_.mode == Final) &&
+        inner.aggregateExpressions.forall(_.mode == PartialMerge) =>
+      outer.copy(child = inner.child)
+
+    // TODO: add similar logic for SortAggregateExec.
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 13490c35679a2..c2e644634199b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -18,21 +18,21 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{execution, Row}
+import org.apache.spark.sql.{execution, QueryTest, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
-import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchange}
+import org.apache.spark.sql.execution.exchange._
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
-class PlannerSuite extends SharedSQLContext {
+class PlannerSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
   setupTestData()
@@ -518,6 +518,24 @@
       fail(s"Should have only two shuffles:\n$outputPlan")
     }
   }
+
+  test("no partial aggregation if input relation is already partitioned") {
+    val input = Seq("a" -> 1, "b" -> 2).toDF("i", "j")
+
+    val aggWithoutDistinct = input.repartition($"i").groupBy($"i").agg(sum($"j"))
+    checkAnswer(aggWithoutDistinct, input.groupBy($"i").agg(sum($"j")))
+    val numShufflesWithoutDistinct = aggWithoutDistinct.queryExecution.executedPlan.collect {
+      case e: Exchange => e
+    }.length
+    assert(numShufflesWithoutDistinct == 1)
+
+    val aggWithDistinct = input.repartition($"i", $"j").groupBy($"i").agg(countDistinct($"j"))
+    checkAnswer(aggWithDistinct, input.groupBy($"i").agg(countDistinct($"j")))
+    val numShufflesWithDistinct = aggWithDistinct.queryExecution.executedPlan.collect {
+      case e: Exchange => e
+    }.length
+    assert(numShufflesWithDistinct == 2)
+  }
 }
 
 // Used for unit-testing EnsureRequirements