Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -1286,6 +1286,14 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

// Internal switch for the `CollapseAggregates` physical rule; that rule returns the plan
// unchanged when this is set to false.
val COLLAPSE_AGGREGATE_NODES_ENABLED = buildConf("spark.sql.execution.collapseAggregateNodes")
  .internal()
  .doc("Whether to collapse a Final aggregate exec node and its child Partial aggregate " +
    "exec node into a single Complete aggregate exec node when there is no exchange " +
    "between them.")
  .version("3.1.0")
  .booleanConf
  .createWithDefault(true)

val REMOVE_REDUNDANT_SORTS_ENABLED = buildConf("spark.sql.execution.removeRedundantSorts")
.internal()
.doc("Whether to remove redundant physical sort node")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Final, Partial}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.internal.SQLConf

/**
 * Collapses adjacent physical aggregate exec nodes into one when they are the Partial and Final
 * stages of the same [[org.apache.spark.sql.catalyst.plans.logical.Aggregate]] logical node and
 * the Final node reads the Partial node directly, i.e. there is no exchange (or any other
 * operator) between them. The merged node evaluates the aggregate functions in Complete mode
 * directly on the Partial node's input.
 *
 * The rule returns the plan unchanged when [[SQLConf.COLLAPSE_AGGREGATE_NODES_ENABLED]] is false.
 */
object CollapseAggregates extends Rule[SparkPlan] {

  override def apply(plan: SparkPlan): SparkPlan = {
    if (conf.getConf(SQLConf.COLLAPSE_AGGREGATE_NODES_ENABLED)) {
      collapseAggregates(plan)
    } else {
      plan
    }
  }

  /**
   * Replaces every collapsible (parent, child) aggregate pair of the same exec type with a single
   * node of that type. The merged node keeps the child's grouping expressions (also used as the
   * required child distribution expressions), the child's aggregate expressions switched to
   * Complete mode, the parent's result expressions, and the child's own child as its input.
   */
  private def collapseAggregates(plan: SparkPlan): SparkPlan = {
    plan transform {
      case parent @ HashAggregateExec(_, _, _, _, _, _, child: HashAggregateExec)
          if checkIfAggregatesCanBeCollapsed(parent, child) =>
        val completeAggregateExpressions = child.aggregateExpressions.map(_.copy(mode = Complete))
        HashAggregateExec(
          requiredChildDistributionExpressions = Some(child.groupingExpressions),
          groupingExpressions = child.groupingExpressions,
          aggregateExpressions = completeAggregateExpressions,
          aggregateAttributes = completeAggregateExpressions.map(_.resultAttribute),
          initialInputBufferOffset = 0,
          resultExpressions = parent.resultExpressions,
          child = child.child)

      case parent @ SortAggregateExec(_, _, _, _, _, _, child: SortAggregateExec)
          if checkIfAggregatesCanBeCollapsed(parent, child) =>
        val completeAggregateExpressions = child.aggregateExpressions.map(_.copy(mode = Complete))
        SortAggregateExec(
          requiredChildDistributionExpressions = Some(child.groupingExpressions),
          groupingExpressions = child.groupingExpressions,
          aggregateExpressions = completeAggregateExpressions,
          aggregateAttributes = completeAggregateExpressions.map(_.resultAttribute),
          initialInputBufferOffset = 0,
          resultExpressions = parent.resultExpressions,
          child = child.child)

      case parent @ ObjectHashAggregateExec(_, _, _, _, _, _, child: ObjectHashAggregateExec)
          if checkIfAggregatesCanBeCollapsed(parent, child) =>
        val completeAggregateExpressions = child.aggregateExpressions.map(_.copy(mode = Complete))
        ObjectHashAggregateExec(
          requiredChildDistributionExpressions = Some(child.groupingExpressions),
          groupingExpressions = child.groupingExpressions,
          aggregateExpressions = completeAggregateExpressions,
          aggregateAttributes = completeAggregateExpressions.map(_.resultAttribute),
          initialInputBufferOffset = 0,
          resultExpressions = parent.resultExpressions,
          child = child.child)
    }
  }

  /**
   * Returns true when `parent` and `child` look like the Final and Partial stages of the same
   * aggregation:
   *  - every aggregate expression of `parent` is in Final mode,
   *  - every aggregate expression of `child` is in Partial mode,
   *  - the parent's aggregate expressions, switched to Partial mode, equal the child's, and
   *  - the parent groups by exactly the attributes produced by the child's grouping expressions.
   *
   * @param parent the upper aggregate node (candidate Final stage)
   * @param child the aggregate node directly below `parent` (candidate Partial stage)
   */
  private def checkIfAggregatesCanBeCollapsed(
      parent: BaseAggregateExec,
      child: BaseAggregateExec): Boolean = {
    val parentAggregates = parent.aggregateExpressions
    val childAggregates = child.aggregateExpressions
    parentAggregates.forall(_.mode == Final) &&
      childAggregates.forall(_.mode == Partial) &&
      parentAggregates.map(_.copy(mode = Partial)) == childAggregates &&
      groupingKeysMatch(parent, child)
  }

  /**
   * Checks that `parent` groups by the output attributes of the grouping expressions of `child`,
   * position by position.
   *
   * The mode checks above use `forall`, which is vacuously true for nodes that have no aggregate
   * functions at all (e.g. a plain group-by used for de-duplication). Without this extra check,
   * two such adjacent nodes would be merged using the child's grouping expressions and the
   * parent's result expressions even if they group by different keys.
   *
   * NOTE(review): this assumes the Final stage is planned with the `toAttribute` form of the
   * Partial stage's grouping expressions -- confirm against the aggregate planner.
   */
  private def groupingKeysMatch(parent: BaseAggregateExec, child: BaseAggregateExec): Boolean = {
    val parentKeys = parent.groupingExpressions
    val childKeys = child.groupingExpressions
    parentKeys.length == childKeys.length &&
      parentKeys.zip(childKeys).forall { case (parentKey, childKey) =>
        parentKey.toAttribute.semanticEquals(childKey.toAttribute)
      }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ object QueryExecution {
// `RemoveRedundantSorts` needs to be added before `EnsureRequirements` to guarantee the same
// number of partitions when instantiating PartitioningCollection.
RemoveRedundantSorts,
CollapseAggregates,
DisableUnnecessaryBucketedScan,
ApplyColumnarRulesAndInsertTransitions(sparkSession.sessionState.columnarRules),
CollapseCodegenStages(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ case class AdaptiveSparkPlanExec(
RemoveRedundantProjects,
EnsureRequirements,
RemoveRedundantSorts,
CollapseAggregates,
DisableUnnecessaryBucketedScan
) ++ context.session.sessionState.queryStagePrepRules

Expand Down
Original file line number Diff line number Diff line change
@@ -1,55 +1,54 @@
== Physical Plan ==
TakeOrderedAndProject (51)
+- * Project (50)
+- * SortMergeJoin Inner (49)
:- * Sort (46)
: +- Exchange (45)
: +- * Project (44)
: +- * SortMergeJoin Inner (43)
: :- * Sort (37)
: : +- Exchange (36)
: : +- * HashAggregate (35)
: : +- * HashAggregate (34)
: : +- * Project (33)
: : +- * SortMergeJoin Inner (32)
: : :- * Sort (26)
: : : +- Exchange (25)
: : : +- * Project (24)
: : : +- * BroadcastHashJoin Inner BuildRight (23)
: : : :- * Project (17)
: : : : +- * BroadcastHashJoin Inner BuildRight (16)
: : : : :- * Project (10)
: : : : : +- * BroadcastHashJoin Inner BuildRight (9)
: : : : : :- * Filter (3)
: : : : : : +- * ColumnarToRow (2)
: : : : : : +- Scan parquet default.store_sales (1)
: : : : : +- BroadcastExchange (8)
: : : : : +- * Project (7)
: : : : : +- * Filter (6)
: : : : : +- * ColumnarToRow (5)
: : : : : +- Scan parquet default.date_dim (4)
: : : : +- BroadcastExchange (15)
: : : : +- * Project (14)
: : : : +- * Filter (13)
: : : : +- * ColumnarToRow (12)
: : : : +- Scan parquet default.store (11)
: : : +- BroadcastExchange (22)
: : : +- * Project (21)
: : : +- * Filter (20)
: : : +- * ColumnarToRow (19)
: : : +- Scan parquet default.household_demographics (18)
: : +- * Sort (31)
: : +- Exchange (30)
: : +- * Filter (29)
: : +- * ColumnarToRow (28)
: : +- Scan parquet default.customer_address (27)
: +- * Sort (42)
: +- Exchange (41)
: +- * Filter (40)
: +- * ColumnarToRow (39)
: +- Scan parquet default.customer (38)
+- * Sort (48)
+- ReusedExchange (47)
TakeOrderedAndProject (50)
+- * Project (49)
+- * SortMergeJoin Inner (48)
:- * Sort (45)
: +- Exchange (44)
: +- * Project (43)
: +- * SortMergeJoin Inner (42)
: :- * Sort (36)
: : +- Exchange (35)
: : +- * HashAggregate (34)
: : +- * Project (33)
: : +- * SortMergeJoin Inner (32)
: : :- * Sort (26)
: : : +- Exchange (25)
: : : +- * Project (24)
: : : +- * BroadcastHashJoin Inner BuildRight (23)
: : : :- * Project (17)
: : : : +- * BroadcastHashJoin Inner BuildRight (16)
: : : : :- * Project (10)
: : : : : +- * BroadcastHashJoin Inner BuildRight (9)
: : : : : :- * Filter (3)
: : : : : : +- * ColumnarToRow (2)
: : : : : : +- Scan parquet default.store_sales (1)
: : : : : +- BroadcastExchange (8)
: : : : : +- * Project (7)
: : : : : +- * Filter (6)
: : : : : +- * ColumnarToRow (5)
: : : : : +- Scan parquet default.date_dim (4)
: : : : +- BroadcastExchange (15)
: : : : +- * Project (14)
: : : : +- * Filter (13)
: : : : +- * ColumnarToRow (12)
: : : : +- Scan parquet default.store (11)
: : : +- BroadcastExchange (22)
: : : +- * Project (21)
: : : +- * Filter (20)
: : : +- * ColumnarToRow (19)
: : : +- Scan parquet default.household_demographics (18)
: : +- * Sort (31)
: : +- Exchange (30)
: : +- * Filter (29)
: : +- * ColumnarToRow (28)
: : +- Scan parquet default.customer_address (27)
: +- * Sort (41)
: +- Exchange (40)
: +- * Filter (39)
: +- * ColumnarToRow (38)
: +- Scan parquet default.customer (37)
+- * Sort (47)
+- ReusedExchange (46)


(1) Scan parquet default.store_sales
Expand Down Expand Up @@ -201,81 +200,74 @@ Input [7]: [ss_customer_sk#2, ss_addr_sk#4, ss_ticket_number#6, ss_coupon_amt#7,
(34) HashAggregate [codegen id : 8]
Input [6]: [ss_customer_sk#2, ss_addr_sk#4, ss_ticket_number#6, ss_coupon_amt#7, ss_net_profit#8, ca_city#22]
Keys [4]: [ss_ticket_number#6, ss_customer_sk#2, ss_addr_sk#4, ca_city#22]
Functions [2]: [partial_sum(UnscaledValue(ss_coupon_amt#7)), partial_sum(UnscaledValue(ss_net_profit#8))]
Aggregate Attributes [2]: [sum#24, sum#25]
Results [6]: [ss_ticket_number#6, ss_customer_sk#2, ss_addr_sk#4, ca_city#22, sum#26, sum#27]

(35) HashAggregate [codegen id : 8]
Input [6]: [ss_ticket_number#6, ss_customer_sk#2, ss_addr_sk#4, ca_city#22, sum#26, sum#27]
Keys [4]: [ss_ticket_number#6, ss_customer_sk#2, ss_addr_sk#4, ca_city#22]
Functions [2]: [sum(UnscaledValue(ss_coupon_amt#7)), sum(UnscaledValue(ss_net_profit#8))]
Aggregate Attributes [2]: [sum(UnscaledValue(ss_coupon_amt#7))#28, sum(UnscaledValue(ss_net_profit#8))#29]
Results [5]: [ss_ticket_number#6, ss_customer_sk#2, ca_city#22 AS bought_city#30, MakeDecimal(sum(UnscaledValue(ss_coupon_amt#7))#28,17,2) AS amt#31, MakeDecimal(sum(UnscaledValue(ss_net_profit#8))#29,17,2) AS profit#32]
Aggregate Attributes [2]: [sum(UnscaledValue(ss_coupon_amt#7))#24, sum(UnscaledValue(ss_net_profit#8))#25]
Results [5]: [ss_ticket_number#6, ss_customer_sk#2, ca_city#22 AS bought_city#26, MakeDecimal(sum(UnscaledValue(ss_coupon_amt#7))#24,17,2) AS amt#27, MakeDecimal(sum(UnscaledValue(ss_net_profit#8))#25,17,2) AS profit#28]

(36) Exchange
Input [5]: [ss_ticket_number#6, ss_customer_sk#2, bought_city#30, amt#31, profit#32]
Arguments: hashpartitioning(ss_customer_sk#2, 5), true, [id=#33]
(35) Exchange
Input [5]: [ss_ticket_number#6, ss_customer_sk#2, bought_city#26, amt#27, profit#28]
Arguments: hashpartitioning(ss_customer_sk#2, 5), true, [id=#29]

(37) Sort [codegen id : 9]
Input [5]: [ss_ticket_number#6, ss_customer_sk#2, bought_city#30, amt#31, profit#32]
(36) Sort [codegen id : 9]
Input [5]: [ss_ticket_number#6, ss_customer_sk#2, bought_city#26, amt#27, profit#28]
Arguments: [ss_customer_sk#2 ASC NULLS FIRST], false, 0

(38) Scan parquet default.customer
Output [4]: [c_customer_sk#34, c_current_addr_sk#35, c_first_name#36, c_last_name#37]
(37) Scan parquet default.customer
Output [4]: [c_customer_sk#30, c_current_addr_sk#31, c_first_name#32, c_last_name#33]
Batched: true
Location [not included in comparison]/{warehouse_dir}/customer]
PushedFilters: [IsNotNull(c_customer_sk), IsNotNull(c_current_addr_sk)]
ReadSchema: struct<c_customer_sk:int,c_current_addr_sk:int,c_first_name:string,c_last_name:string>

(39) ColumnarToRow [codegen id : 10]
Input [4]: [c_customer_sk#34, c_current_addr_sk#35, c_first_name#36, c_last_name#37]
(38) ColumnarToRow [codegen id : 10]
Input [4]: [c_customer_sk#30, c_current_addr_sk#31, c_first_name#32, c_last_name#33]

(40) Filter [codegen id : 10]
Input [4]: [c_customer_sk#34, c_current_addr_sk#35, c_first_name#36, c_last_name#37]
Condition : (isnotnull(c_customer_sk#34) AND isnotnull(c_current_addr_sk#35))
(39) Filter [codegen id : 10]
Input [4]: [c_customer_sk#30, c_current_addr_sk#31, c_first_name#32, c_last_name#33]
Condition : (isnotnull(c_customer_sk#30) AND isnotnull(c_current_addr_sk#31))

(41) Exchange
Input [4]: [c_customer_sk#34, c_current_addr_sk#35, c_first_name#36, c_last_name#37]
Arguments: hashpartitioning(c_customer_sk#34, 5), true, [id=#38]
(40) Exchange
Input [4]: [c_customer_sk#30, c_current_addr_sk#31, c_first_name#32, c_last_name#33]
Arguments: hashpartitioning(c_customer_sk#30, 5), true, [id=#34]

(42) Sort [codegen id : 11]
Input [4]: [c_customer_sk#34, c_current_addr_sk#35, c_first_name#36, c_last_name#37]
Arguments: [c_customer_sk#34 ASC NULLS FIRST], false, 0
(41) Sort [codegen id : 11]
Input [4]: [c_customer_sk#30, c_current_addr_sk#31, c_first_name#32, c_last_name#33]
Arguments: [c_customer_sk#30 ASC NULLS FIRST], false, 0

(43) SortMergeJoin [codegen id : 12]
(42) SortMergeJoin [codegen id : 12]
Left keys [1]: [ss_customer_sk#2]
Right keys [1]: [c_customer_sk#34]
Right keys [1]: [c_customer_sk#30]
Join condition: None

(44) Project [codegen id : 12]
Output [7]: [ss_ticket_number#6, bought_city#30, amt#31, profit#32, c_current_addr_sk#35, c_first_name#36, c_last_name#37]
Input [9]: [ss_ticket_number#6, ss_customer_sk#2, bought_city#30, amt#31, profit#32, c_customer_sk#34, c_current_addr_sk#35, c_first_name#36, c_last_name#37]
(43) Project [codegen id : 12]
Output [7]: [ss_ticket_number#6, bought_city#26, amt#27, profit#28, c_current_addr_sk#31, c_first_name#32, c_last_name#33]
Input [9]: [ss_ticket_number#6, ss_customer_sk#2, bought_city#26, amt#27, profit#28, c_customer_sk#30, c_current_addr_sk#31, c_first_name#32, c_last_name#33]

(45) Exchange
Input [7]: [ss_ticket_number#6, bought_city#30, amt#31, profit#32, c_current_addr_sk#35, c_first_name#36, c_last_name#37]
Arguments: hashpartitioning(c_current_addr_sk#35, 5), true, [id=#39]
(44) Exchange
Input [7]: [ss_ticket_number#6, bought_city#26, amt#27, profit#28, c_current_addr_sk#31, c_first_name#32, c_last_name#33]
Arguments: hashpartitioning(c_current_addr_sk#31, 5), true, [id=#35]

(46) Sort [codegen id : 13]
Input [7]: [ss_ticket_number#6, bought_city#30, amt#31, profit#32, c_current_addr_sk#35, c_first_name#36, c_last_name#37]
Arguments: [c_current_addr_sk#35 ASC NULLS FIRST], false, 0
(45) Sort [codegen id : 13]
Input [7]: [ss_ticket_number#6, bought_city#26, amt#27, profit#28, c_current_addr_sk#31, c_first_name#32, c_last_name#33]
Arguments: [c_current_addr_sk#31 ASC NULLS FIRST], false, 0

(47) ReusedExchange [Reuses operator id: 30]
(46) ReusedExchange [Reuses operator id: 30]
Output [2]: [ca_address_sk#21, ca_city#22]

(48) Sort [codegen id : 15]
(47) Sort [codegen id : 15]
Input [2]: [ca_address_sk#21, ca_city#22]
Arguments: [ca_address_sk#21 ASC NULLS FIRST], false, 0

(49) SortMergeJoin [codegen id : 16]
Left keys [1]: [c_current_addr_sk#35]
(48) SortMergeJoin [codegen id : 16]
Left keys [1]: [c_current_addr_sk#31]
Right keys [1]: [ca_address_sk#21]
Join condition: NOT (ca_city#22 = bought_city#30)
Join condition: NOT (ca_city#22 = bought_city#26)

(50) Project [codegen id : 16]
Output [7]: [c_last_name#37, c_first_name#36, ca_city#22, bought_city#30, ss_ticket_number#6, amt#31, profit#32]
Input [9]: [ss_ticket_number#6, bought_city#30, amt#31, profit#32, c_current_addr_sk#35, c_first_name#36, c_last_name#37, ca_address_sk#21, ca_city#22]
(49) Project [codegen id : 16]
Output [7]: [c_last_name#33, c_first_name#32, ca_city#22, bought_city#26, ss_ticket_number#6, amt#27, profit#28]
Input [9]: [ss_ticket_number#6, bought_city#26, amt#27, profit#28, c_current_addr_sk#31, c_first_name#32, c_last_name#33, ca_address_sk#21, ca_city#22]

(51) TakeOrderedAndProject
Input [7]: [c_last_name#37, c_first_name#36, ca_city#22, bought_city#30, ss_ticket_number#6, amt#31, profit#32]
Arguments: 100, [c_last_name#37 ASC NULLS FIRST, c_first_name#36 ASC NULLS FIRST, ca_city#22 ASC NULLS FIRST, bought_city#30 ASC NULLS FIRST, ss_ticket_number#6 ASC NULLS FIRST], [c_last_name#37, c_first_name#36, ca_city#22, bought_city#30, ss_ticket_number#6, amt#31, profit#32]
(50) TakeOrderedAndProject
Input [7]: [c_last_name#33, c_first_name#32, ca_city#22, bought_city#26, ss_ticket_number#6, amt#27, profit#28]
Arguments: 100, [c_last_name#33 ASC NULLS FIRST, c_first_name#32 ASC NULLS FIRST, ca_city#22 ASC NULLS FIRST, bought_city#26 ASC NULLS FIRST, ss_ticket_number#6 ASC NULLS FIRST], [c_last_name#33, c_first_name#32, ca_city#22, bought_city#26, ss_ticket_number#6, amt#27, profit#28]

Loading