Commit 156d8d9

xingchaozh authored and GitHub Enterprise committed
[CARMEL-6151] Exchange Push Down through Aggregate (#1053)
* [CARMEL-6151] Exchange Push Down through Aggregate
* fix ut
* fix ut
1 parent 53f3368 commit 156d8d9

File tree

5 files changed: +346 −2 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 8 additions & 0 deletions
@@ -1303,6 +1303,14 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val EXCHANGE_PUSH_DOWN_THROUGH_AGGREGATE_ENABLED =
+    buildConf("spark.sql.exchangePushDownThroughAggregate.enabled")
+      .internal()
+      .doc("When true, we will try to push down exchange through aggregate.")
+      .version("3.0.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val AUTO_APPLY_STAGE_FALLBACK_PLAN_ENABLED =
     buildConf("spark.sql.applyStageFallbackPlan.enabled")
       .internal()
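
The new entry is internal and defaults to false, so the rule below is opt-in. A minimal sketch of enabling it for a session (illustrative only, not part of this commit):

// Illustrative only: opt in to the new rule for the current session.
// The key and default come from the SQLConf entry added above.
spark.conf.set("spark.sql.exchangePushDownThroughAggregate.enabled", "true")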

sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala

Lines changed: 2 additions & 2 deletions
@@ -37,8 +37,7 @@ import org.apache.spark.sql.execution.QueryExecution.skipAuthTag
 import org.apache.spark.sql.execution.adaptive.{AdaptiveExecutionContext, EnsureRepartitionForWriting, InsertAdaptiveSparkPlan}
 import org.apache.spark.sql.execution.bucketing.DisableUnnecessaryBucketedScan
 import org.apache.spark.sql.execution.dynamicpruning.PlanDynamicPruningFilters
-import org.apache.spark.sql.execution.exchange.EliminateShuffleExec
-import org.apache.spark.sql.execution.exchange.EnsureRequirements
+import org.apache.spark.sql.execution.exchange.{EliminateShuffleExec, EnsureRequirements, ExchangePushDownThroughAggregate}
 import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery
 import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata}
 import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
@@ -366,6 +365,7 @@ object QueryExecution {
     PlanSubqueries(sparkSession),
     EliminateHintPlaceHolder,
     EnsureRequirements,
+    ExchangePushDownThroughAggregate,
     // `RemoveRedundantSorts` needs to be added after `EnsureRequirements` to guarantee the same
     // number of partitions when instantiating PartitioningCollection.
     RemoveRedundantSorts,

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala

Lines changed: 1 addition & 0 deletions
@@ -91,6 +91,7 @@ case class AdaptiveSparkPlanExec(
   private def queryStagePreparationRules: Seq[Rule[SparkPlan]] = Seq(
     eliminateHintPlaceHolder,
     ensureRequirements,
+    ExchangePushDownThroughAggregate,
     removeRedundantSorts,
     RemoveRedundantPartialAggregates,
     EnsureRepartitionForWriting,
Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.exchange

import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, HashClusteredDistribution, HashPartitioning}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec}
import org.apache.spark.sql.internal.SQLConf

/**
 * An Exchange can be pushed down through an aggregate to reduce the shuffle count or
 * improve the parallelism of the aggregate.
 */
object ExchangePushDownThroughAggregate extends Rule[SparkPlan] {
  private def isPushDownSupport(agg: BaseAggregateExec,
      partialAgg: BaseAggregateExec,
      shuffle: ShuffleExchangeExec): Boolean = {
    val validFinalAggregate = agg.aggregateExpressions.forall(_.mode == Final)
    val validPartialAggregate = partialAgg.aggregateExpressions.forall(_.mode == Partial)

    val ensureRequirement = agg.requiredChildDistribution
      .forall(shuffle.outputPartitioning.satisfies)

    val containsAll = shuffle.outputPartitioning match {
      case h: HashPartitioning => h.expressions.forall {
        e => agg.groupingExpressions.find(_.semanticEquals(e)).nonEmpty
      }
      case _ => false
    }

    validFinalAggregate && validPartialAggregate && ensureRequirement && containsAll
  }

  private def exchangePushDown(plan: SparkPlan): SparkPlan = plan transform {
    /* Aggregate on bucket table
     *
     * Shuffle(j)                           HashAggregate(i, ..., Final)
     *   |                                    |
     * HashAggregate(i, ..., Final)         HashAggregate(i, ..., Partial)
     *   |                           ->       |
     * HashAggregate(i, ..., Partial)       Shuffle(j)
     *   |                                    |
     * Project                              Project
     *   |                                    |
     * Filter                               Filter
     *   |                                    |
     * Scan(t1: i, j)                       Scan(t1: i, j)
     *   -- bucket on i
     */
    case shuffle @ ShuffleExchangeExec(_: HashPartitioning,
        agg @ HashAggregateExec(_, _, _, _, _, _,
          partialAgg @ HashAggregateExec(_, _, _, _, _, _,
            ProjectExec(_, FilterExec(_, scan: FileSourceScanExec)))), _)
        if isPushDownSupport(agg, partialAgg, shuffle) && scan.relation.bucketSpec.nonEmpty =>
      val newShuffle = shuffle.withNewChildren(partialAgg.children)
      newShuffle.addOptimizeTag(s"created by ${this.simpleRuleName}")

      partialAgg.logicalLink.foreach(newShuffle.setLogicalLink)
      val newPartialAgg = partialAgg.withNewChildren(newShuffle :: Nil)
      agg.withNewChildren(newPartialAgg :: Nil)

    case shuffle @ ShuffleExchangeExec(_: HashPartitioning,
        agg @ HashAggregateExec(_, _, _, _, _, _,
          partialAgg @ HashAggregateExec(_, _, _, _, _, _, scan: FileSourceScanExec)), _)
        if isPushDownSupport(agg, partialAgg, shuffle) && scan.relation.bucketSpec.nonEmpty =>
      val newShuffle = shuffle.withNewChildren(partialAgg.children)
      newShuffle.addOptimizeTag(s"created by ${this.simpleRuleName}")

      partialAgg.logicalLink.foreach(newShuffle.setLogicalLink)
      val newPartialAgg = partialAgg.withNewChildren(newShuffle :: Nil)
      agg.withNewChildren(newPartialAgg :: Nil)

    /* Aggregate on non-bucket table
     *
     * Shuffle(j)
     *   |
     * HashAggregate(i, j, Final)           HashAggregate(i, j, Final)
     *   |                                    |
     * Shuffle(i, j)                 ->     HashAggregate(i, j, Partial)
     *   |                                    |
     * HashAggregate(i, j, Partial)         Shuffle(j)
     *   |                                    |
     * Project                              Project
     *   |                                    |
     * Filter                               Filter
     *   |                                    |
     * Scan(t1: i, j)                       Scan(t1: i, j)
     */
    case shuffle @ ShuffleExchangeExec(_: HashPartitioning,
        agg @ HashAggregateExec(_, _, _, _, _, _, ShuffleExchangeExec(_: HashPartitioning,
          partialAgg @ HashAggregateExec(_, _, _, _, _, _,
            ProjectExec(_, FilterExec(_, _: FileSourceScanExec))), _)), _)
        if isPushDownSupport(agg, partialAgg, shuffle) =>
      val newShuffle = shuffle.withNewChildren(partialAgg.children)
      newShuffle.addOptimizeTag(s"created by ${this.simpleRuleName}")

      partialAgg.logicalLink.foreach(newShuffle.setLogicalLink)
      val newPartialAgg = partialAgg.withNewChildren(newShuffle :: Nil)
      agg.withNewChildren(newPartialAgg :: Nil)

    case shuffle @ ShuffleExchangeExec(_: HashPartitioning,
        agg @ HashAggregateExec(_, _, _, _, _, _, ShuffleExchangeExec(_: HashPartitioning,
          partialAgg @ HashAggregateExec(_, _, _, _, _, _, _: FileSourceScanExec), _)), _)
        if isPushDownSupport(agg, partialAgg, shuffle) =>
      val newShuffle = shuffle.withNewChildren(partialAgg.children)
      newShuffle.addOptimizeTag(s"created by ${this.simpleRuleName}")

      partialAgg.logicalLink.foreach(newShuffle.setLogicalLink)
      val newPartialAgg = partialAgg.withNewChildren(newShuffle :: Nil)
      agg.withNewChildren(newPartialAgg :: Nil)
  }

  def apply(plan: SparkPlan): SparkPlan = {
    if (!conf.getConf(SQLConf.EXCHANGE_PUSH_DOWN_THROUGH_AGGREGATE_ENABLED)) {
      plan
    } else {
      val newPlan = exchangePushDown(plan)
      newPlan
    }
  }
}
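
The class comment and diagrams above describe the rewrite at the plan level. As a rough sketch of the kind of query that hits the non-bucket pattern (illustrative, not part of this commit, and assuming the normal_table1/normal_table2 tables created by the test suite below), one could run:

import org.apache.spark.sql.internal.SQLConf

// Illustrative only; mirrors the "basic test for normal table" case in the suite below.
// Assumes an active SparkSession `spark` and the two tables created as in that test.
spark.conf.set(SQLConf.EXCHANGE_PUSH_DOWN_THROUGH_AGGREGATE_ENABLED.key, "true")
spark.conf.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "1") // force a shuffle join, as the suite does

val df = spark.sql(
  """
    |select l, i, j, k from
    |normal_table2 t0 left join (select distinct i, j, k from normal_table1) t1
    |on t0.l = t1.j
    |""".stripMargin)

// Matches the non-bucket diagram above: the exchange on the join key j is pushed below the
// partial HashAggregate and the exchange on (i, j) goes away, so the executed plan keeps
// two shuffles instead of three (the suite asserts 2 vs. 3 for this query shape).
df.explain()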
Lines changed: 196 additions & 0 deletions
@@ -0,0 +1,196 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.sources

import org.apache.spark.sql.{DataFrame, QueryTest}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}

class ExchangePushDownThroughAggregateWithoutHiveSupportSuite
  extends ExchangePushDownThroughAggregateSuite
  with SharedSparkSession
  with DisableAdaptiveExecutionSuite {

  protected override def beforeAll(): Unit = {
    super.beforeAll()
    assert(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory")
  }
}

class ExchangePushDownThroughAggregateWithoutHiveSupportSuiteAE
  extends ExchangePushDownThroughAggregateSuite
  with SharedSparkSession
  with EnableAdaptiveExecutionSuite {

  protected override def beforeAll(): Unit = {
    super.beforeAll()
    assert(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory")
  }
}

abstract class ExchangePushDownThroughAggregateSuite extends QueryTest
  with SQLTestUtils with AdaptiveSparkPlanHelper {

  // protected override def beforeAll(): Unit = {
  //   super.beforeAll()
  //   assert(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory")
  // }

  import testImplicits._

  private lazy val df1 =
    (0 until 50).map(i => (i % 2, i % 4, i % 8)).toDF("i", "j", "k").as("df1")
  private lazy val df2 =
    (0 until 50).map(i => (i % 3, i % 6, i % 9)).toDF("i", "j", "k").as("df2")
  private lazy val df3 =
    (0 until 50).map(i => (i % 3, i % 6, i % 9)).toDF("l", "m", "n").as("df3")

  private def checkExchangePushDown(query: String,
      shuffleCountIfEnable: Int,
      shuffleCountIfDisable: Int): Unit = {

    def checkExchangePushDownResult(query: String, enabled: Boolean,
        expectedNumShuffle: Int): DataFrame = {
      val df = sql(query)
      df.collect()

      val plan = df.queryExecution.executedPlan

      // scalastyle:off println
      println(s"query: ${query},\nenabled: ${enabled},\nplan: ${plan}")
      // scalastyle:on println

      val shuffles = collect(plan) { case s: ShuffleExchangeExec => s }
      assert(shuffles.length == expectedNumShuffle)
      df
    }

    withSQLConf(SQLConf.EXCHANGE_PUSH_DOWN_THROUGH_AGGREGATE_ENABLED.key -> "true") {
      val result = checkExchangePushDownResult(query, true, shuffleCountIfEnable)

      withSQLConf(SQLConf.EXCHANGE_PUSH_DOWN_THROUGH_AGGREGATE_ENABLED.key -> "false") {
        val result2 = checkExchangePushDownResult(query, false, shuffleCountIfDisable)
        checkAnswer(result, result2)
      }
    }
  }

  test("Exchange push down through aggregate - basic test") {
    withTable("bucket_table1", "normal_table1", "normal_table2") {
      df1.write.format("parquet").bucketBy(8, "i").
        saveAsTable("bucket_table1")
      df2.write.format("parquet").saveAsTable("normal_table1")
      df3.write.format("parquet").saveAsTable("normal_table2")
      // df1.write.format("parquet").saveAsTable("t3")

      Seq(
        // (
        //   """
        //     |select * from (select i, j, k from
        //     |(select distinct i, j, k from
        //     |(select i, j, k from bucket_table1 cluster by j)
        //     | t2 ) t1 cluster by j) t0 order by i, j, k
        //     |""".stripMargin, 1, 2),

        // (
        //   """
        //     |select * from (select i, j, k from
        //     |(select distinct i, j, k from bucket_table1
        //     | t2 ) t1 cluster by j ) t0 order by i, j, k
        //     |""".stripMargin, 1, 1),

        // basic test for bucket table
        (
          """
            |select l, i, j, k from
            |normal_table2 t0 left join (select distinct i, j, k from bucket_table1
            |) t1 on t0.l = t1.j
            |""".stripMargin, 2, 2),
        (
          """
            |select l, i, j, k from
            |normal_table2 t0 left join (
            |select i, j, k, count(*) as c from bucket_table1 group by 1, 2, 3
            |) t1 on t0.l = t1.c
            |""".stripMargin, 2, 2), // No support since clustered on c instead of i, j, k
        // with filter
        (
          """
            |select l, i, j, k from
            |normal_table2 t0 left join (select distinct i, j, k from bucket_table1
            |where i > 0) t1 on t0.l = t1.j
            |""".stripMargin, 2, 2),
        // with one aggregate function
        (
          """
            |select l, i, j, k, mi from
            |normal_table2 t0 left join (select distinct i, j, k, max(i) as mi from bucket_table1
            |where i > 0 group by 1, 2, 3) t1 on t0.l = t1.j
            |""".stripMargin, 2, 2),
        // with 2 aggregate functions
        (
          """
            |select l, i, j, k, mi, cnt from
            |normal_table2 t0 left join
            |(select distinct i, j, k, max(i) as mi, count(1) as cnt from bucket_table1
            |where i > 0 group by 1, 2, 3) t1 on t0.l = t1.j
            |""".stripMargin, 2, 2),

        // basic test for normal table
        (
          """
            |select l, i, j, k from
            |normal_table2 t0 left join (select distinct i, j, k from normal_table1
            |) t1 on t0.l = t1.j
            |""".stripMargin, 2, 3),
        // with filter
        (
          """
            |select l, i, j, k from
            |normal_table2 t0 left join (select distinct i, j, k from normal_table1
            |where i > 0) t1 on t0.l = t1.j
            |""".stripMargin, 2, 3),
        // with one aggregate function
        (
          """
            |select l, i, j, k, mi from
            |normal_table2 t0 left join (select distinct i, j, k, max(i) as mi from normal_table1
            |where i > 0 group by 1, 2, 3) t1 on t0.l = t1.j
            |""".stripMargin, 2, 3),
        // with 2 aggregate functions
        (
          """
            |select l, i, j, k, mi, cnt from
            |normal_table2 t0 left join
            |(select distinct i, j, k, max(i) as mi, count(1) as cnt from normal_table1
            |where i > 0 group by 1, 2, 3) t1 on t0.l = t1.j
            |""".stripMargin, 2, 3)
      ).foreach { case (query, shuffleCountIfEnable, shuffleCountIfDisable) =>
        withSQLConf(SQLConf.AUTO_BUCKETED_SCAN_ENABLED.key -> "true",
            SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") {
          checkExchangePushDown(query, shuffleCountIfEnable, shuffleCountIfDisable)
        }
      }
    }
  }
}
