diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
index 396c9c9d6b4e5..f348b1cb86e5b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
@@ -250,10 +250,7 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
     }
   }
 
-  override def apply(plan: SparkPlan): SparkPlan = {
-    if (!conf.getConf(SQLConf.SKEW_JOIN_ENABLED)) {
-      return plan
-    }
+  private def tryOptimize(plan: SparkPlan): SparkPlan = {
 
     def collectShuffleStages(plan: SparkPlan): Seq[ShuffleQueryStageExec] = plan match {
       case stage: ShuffleQueryStageExec => Seq(stage)
@@ -286,6 +283,36 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
       plan
     }
   }
+
+  override def apply(plan: SparkPlan): SparkPlan = {
+    if (!conf.getConf(SQLConf.SKEW_JOIN_ENABLED)) {
+      return plan
+    }
+
+    // Try to handle skew join with union case, like
+    // Union
+    //   SMJ
+    //     Sort
+    //       Shuffle
+    //     Sort
+    //       Shuffle
+    //   SMJ
+    //     Sort
+    //       Shuffle
+    //     Sort
+    //       Shuffle
+    var containsUnion = false
+    val optimizedUnion = plan transformUp {
+      case u @ UnionExec(children) =>
+        containsUnion = true
+        u.withNewChildren(children.map(tryOptimize))
+    }
+    if (containsUnion) {
+      optimizedUnion
+    } else {
+      tryOptimize(plan)
+    }
+  }
 }
 
 private object ShuffleStage {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 9fa97bffa8910..cf34a6a856f8d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListe
 import org.apache.spark.sql.{QueryTest, Row, SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
-import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, ReusedSubqueryExec, ShuffledRowRDD, SparkPlan}
+import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, ReusedSubqueryExec, ShuffledRowRDD, SparkPlan, UnionExec}
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
@@ -719,6 +719,105 @@ class AdaptiveQueryExecSuite
     }
   }
 
+  test("SPARK-32129: adaptive skew join with union") {
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
+      SQLConf.SHUFFLE_PARTITIONS.key -> "100",
+      SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "800",
+      SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "800") {
+      withTempView("skewData1", "skewData2", "skewData3", "skewData4") {
+        spark
+          .range(0, 1000, 1, 10)
+          .select(
+            when('id < 250, 249)
+              .when('id >= 750, 1000)
+              .otherwise('id).as("key1"),
+            'id as "value1")
+          .createOrReplaceTempView("skewData1")
+        spark
+          .range(0, 1000, 1, 10)
+          .select(
+            when('id < 250, 249)
+              .otherwise('id).as("key2"),
+            'id as "value2")
+          .createOrReplaceTempView("skewData2")
+        spark
+          .range(0, 1000, 1, 10)
+          .select(
+            when('id < 250, 249)
+              .when('id >= 750, 1000)
+              .otherwise('id).as("key3"),
+            'id as "value3")
+          .createOrReplaceTempView("skewData3")
+        spark
+          .range(0, 1000, 1, 10)
+          .select(
+            when('id < 250, 249)
+              .otherwise('id).as("key4"),
+            'id as "value4")
+          .createOrReplaceTempView("skewData4")
+
+        def checkSkewJoin(
+            joins: Seq[SortMergeJoinExec],
+            leftSkewNum: Int,
+            rightSkewNum: Int): Unit = {
+          assert(joins.size == 1 && joins.head.isSkewJoin)
+          assert(joins.head.left.collect {
+            case r: CustomShuffleReaderExec => r
+          }.head.partitionSpecs.collect {
+            case p: PartialReducerPartitionSpec => p.reducerIndex
+          }.distinct.length == leftSkewNum)
+          assert(joins.head.right.collect {
+            case r: CustomShuffleReaderExec => r
+          }.head.partitionSpecs.collect {
+            case p: PartialReducerPartitionSpec => p.reducerIndex
+          }.distinct.length == rightSkewNum)
+        }
+
+        // skewed inner join optimization with union (not union all)
+        val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(
+          "SELECT * FROM skewData1 join skewData2 ON key1 = key2 " +
+            "UNION SELECT * FROM skewData3 join skewData4 ON key3 = key4")
+        innerAdaptivePlan transformUp {
+          case u @ UnionExec(c) =>
+            val innerSmj1 = findTopLevelSortMergeJoin(c.head)
+            val innerSmj2 = findTopLevelSortMergeJoin(c.tail.head)
+            checkSkewJoin(innerSmj1, 2, 1)
+            checkSkewJoin(innerSmj2, 2, 1)
+            u
+        }
+
+        // skewed left outer join optimization with union all
+        val (_, leftAdaptivePlan) = runAdaptiveAndVerifyResult(
+          "SELECT * FROM skewData1 left outer join skewData2 ON key1 = key2 " +
+            "UNION ALL SELECT * FROM skewData3 left outer join skewData4 ON key3 = key4")
+        leftAdaptivePlan transformUp {
+          case u @ UnionExec(c) =>
+            val leftSmj1 = findTopLevelSortMergeJoin(c.head)
+            val leftSmj2 = findTopLevelSortMergeJoin(c.tail.head)
+            checkSkewJoin(leftSmj1, 2, 0)
+            checkSkewJoin(leftSmj2, 2, 0)
+            u
+        }
+
+        // skewed right outer join optimization with union all
+        val (_, rightAdaptivePlan) = runAdaptiveAndVerifyResult(
+          "SELECT * FROM skewData1 right outer join skewData2 ON key1 = key2 " +
+            "UNION ALL SELECT * FROM skewData3 right outer join skewData4 ON key3 = key4")
+        rightAdaptivePlan transformUp {
+          case u @ UnionExec(c) =>
+            val rightSmj1 = findTopLevelSortMergeJoin(c.head)
+            val rightSmj2 = findTopLevelSortMergeJoin(c.tail.head)
+            checkSkewJoin(rightSmj1, 0, 1)
+            checkSkewJoin(rightSmj2, 0, 1)
+            u
+        }
+      }
+    }
+  }
+
   test("SPARK-30291: AQE should catch the exceptions when doing materialize") {
     withSQLConf(
       SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {