Skip to content

Commit d011e9a

Browse files
committed
[SPARK-32129][SQL] Support AQE skew join with Union
1 parent e29ec42 commit d011e9a

File tree

2 files changed

+131
-5
lines changed

2 files changed

+131
-5
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -250,10 +250,7 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
250250
}
251251
}
252252

253-
override def apply(plan: SparkPlan): SparkPlan = {
254-
if (!conf.getConf(SQLConf.SKEW_JOIN_ENABLED)) {
255-
return plan
256-
}
253+
private def tryOptimize(plan: SparkPlan): SparkPlan = {
257254

258255
def collectShuffleStages(plan: SparkPlan): Seq[ShuffleQueryStageExec] = plan match {
259256
case stage: ShuffleQueryStageExec => Seq(stage)
@@ -286,6 +283,36 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
286283
plan
287284
}
288285
}
286+
287+
/**
 * Entry point of the rule.
 *
 * Returns `plan` untouched when `SQLConf.SKEW_JOIN_ENABLED` is off. Otherwise:
 *  - if the plan contains at least one `UnionExec`, every child of every union is passed
 *    through `tryOptimize` on its own and the rest of the plan is kept as is (the plan as a
 *    whole is not handed to `tryOptimize` in this case);
 *  - if there is no union, the whole plan is handed to `tryOptimize`.
 *
 * The union handling targets plans shaped like:
 * {{{
 *   Union
 *     SMJ
 *       Sort
 *         Shuffle
 *       Sort
 *         Shuffle
 *     SMJ
 *       Sort
 *         Shuffle
 *       Sort
 *         Shuffle
 * }}}
 */
override def apply(plan: SparkPlan): SparkPlan = {
  val skewJoinEnabled = conf.getConf(SQLConf.SKEW_JOIN_ENABLED)
  // Look for a union up front so no mutable flag has to be flipped inside the transform.
  lazy val hasUnion = plan.collectFirst { case union: UnionExec => union }.isDefined

  if (!skewJoinEnabled) {
    plan
  } else if (hasUnion) {
    // Bottom-up, so the branches of nested unions are rewritten before the enclosing union.
    plan transformUp {
      case union @ UnionExec(branches) =>
        union.withNewChildren(branches.map(tryOptimize))
    }
  } else {
    tryOptimize(plan)
  }
}
289316
}
290317

291318
private object ShuffleStage {

sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListe
2626
import org.apache.spark.sql.{QueryTest, Row, SparkSession, Strategy}
2727
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
2828
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
29-
import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, ReusedSubqueryExec, ShuffledRowRDD, SparkPlan}
29+
import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, ReusedSubqueryExec, ShuffledRowRDD, SparkPlan, UnionExec}
3030
import org.apache.spark.sql.execution.command.DataWritingCommandExec
3131
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec}
3232
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
@@ -719,6 +719,105 @@ class AdaptiveQueryExecSuite
719719
}
720720
}
721721

722+
test("SPARK-32129: adaptive skew join with union") {
  withSQLConf(
    SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
    SQLConf.SHUFFLE_PARTITIONS.key -> "100",
    SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "800",
    SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "800") {
    withTempView("skewData1", "skewData2", "skewData3", "skewData4") {

      // Registers a temp view over ids 0 until 1000 (10 partitions). The key column maps
      // every id below 250 to 249 and, when `skewUpperRange` is set, every id of 750 or
      // more to 1000; all other ids are kept. The value column is the id itself.
      def registerSkewedView(
          viewName: String,
          keyName: String,
          valueName: String,
          skewUpperRange: Boolean): Unit = {
        val lowerSkew = when('id < 250, 249)
        val keyRule = if (skewUpperRange) lowerSkew.when('id >= 750, 1000) else lowerSkew
        spark
          .range(0, 1000, 1, 10)
          .select(keyRule.otherwise('id).as(keyName), 'id.as(valueName))
          .createOrReplaceTempView(viewName)
      }

      registerSkewedView("skewData1", "key1", "value1", skewUpperRange = true)
      registerSkewedView("skewData2", "key2", "value2", skewUpperRange = false)
      registerSkewedView("skewData3", "key3", "value3", skewUpperRange = true)
      registerSkewedView("skewData4", "key4", "value4", skewUpperRange = false)

      // Number of distinct reducer indices that the first CustomShuffleReaderExec found
      // under `joinSide` reads through PartialReducerPartitionSpec entries.
      def skewedReducerCount(joinSide: SparkPlan): Int = {
        val firstReader = joinSide.collect {
          case reader: CustomShuffleReaderExec => reader
        }.head
        firstReader.partitionSpecs
          .collect { case spec: PartialReducerPartitionSpec => spec.reducerIndex }
          .distinct
          .length
      }

      // Expects exactly one join, flagged as a skew join, with the given number of
      // distinct split reducers on each side.
      def checkSkewJoin(
          joins: Seq[SortMergeJoinExec],
          leftSkewNum: Int,
          rightSkewNum: Int): Unit = {
        assert(joins.size == 1 && joins.head.isSkewJoin)
        val skewJoin = joins.head
        assert(skewedReducerCount(skewJoin.left) == leftSkewNum)
        assert(skewedReducerCount(skewJoin.right) == rightSkewNum)
      }

      // Runs `query` adaptively and applies checkSkewJoin to the first two children of
      // every UnionExec met while walking the resulting plan.
      // NOTE(review): the checks only run for UnionExec nodes this walk actually matches;
      // if none is matched the test passes without asserting anything about the joins.
      // Consider also asserting that at least one union was found.
      def checkUnionBranches(query: String, leftSkewNum: Int, rightSkewNum: Int): Unit = {
        val (_, adaptivePlan) = runAdaptiveAndVerifyResult(query)
        adaptivePlan.foreach {
          case UnionExec(branches) =>
            val firstBranchJoins = findTopLevelSortMergeJoin(branches.head)
            val secondBranchJoins = findTopLevelSortMergeJoin(branches.tail.head)
            checkSkewJoin(firstBranchJoins, leftSkewNum, rightSkewNum)
            checkSkewJoin(secondBranchJoins, leftSkewNum, rightSkewNum)
          case _ =>
        }
      }

      // skewed inner join optimization with union (not union all)
      checkUnionBranches(
        "SELECT * FROM skewData1 join skewData2 ON key1 = key2 " +
          "UNION SELECT * FROM skewData3 join skewData4 ON key3 = key4",
        leftSkewNum = 2,
        rightSkewNum = 1)

      // skewed left outer join optimization with union all
      checkUnionBranches(
        "SELECT * FROM skewData1 left outer join skewData2 ON key1 = key2 " +
          "UNION ALL SELECT * FROM skewData3 left outer join skewData4 ON key3 = key4",
        leftSkewNum = 2,
        rightSkewNum = 0)

      // skewed right outer join optimization with union all
      checkUnionBranches(
        "SELECT * FROM skewData1 right outer join skewData2 ON key1 = key2 " +
          "UNION ALL SELECT * FROM skewData3 right outer join skewData4 ON key3 = key4",
        leftSkewNum = 0,
        rightSkewNum = 1)
    }
  }
}
820+
722821
test("SPARK-30291: AQE should catch the exceptions when doing materialize") {
723822
withSQLConf(
724823
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {

0 commit comments

Comments
 (0)