@@ -26,7 +26,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListe
2626import org .apache .spark .sql .{QueryTest , Row , SparkSession , Strategy }
2727import org .apache .spark .sql .catalyst .optimizer .{BuildLeft , BuildRight }
2828import org .apache .spark .sql .catalyst .plans .logical .{Aggregate , LogicalPlan }
29- import org .apache .spark .sql .execution .{PartialReducerPartitionSpec , ReusedSubqueryExec , ShuffledRowRDD , SparkPlan }
29+ import org .apache .spark .sql .execution .{PartialReducerPartitionSpec , ReusedSubqueryExec , ShuffledRowRDD , SparkPlan , UnionExec }
3030import org .apache .spark .sql .execution .command .DataWritingCommandExec
3131import org .apache .spark .sql .execution .exchange .{BroadcastExchangeExec , Exchange , ReusedExchangeExec }
3232import org .apache .spark .sql .execution .joins .{BroadcastHashJoinExec , SortMergeJoinExec }
@@ -719,6 +719,105 @@ class AdaptiveQueryExecSuite
719719 }
720720 }
721721
722+ test(" SPARK-32129: adaptive skew join with union" ) {
723+ withSQLConf(
724+ SQLConf .ADAPTIVE_EXECUTION_ENABLED .key -> " true" ,
725+ SQLConf .AUTO_BROADCASTJOIN_THRESHOLD .key -> " -1" ,
726+ SQLConf .COALESCE_PARTITIONS_MIN_PARTITION_NUM .key -> " 1" ,
727+ SQLConf .SHUFFLE_PARTITIONS .key -> " 100" ,
728+ SQLConf .SKEW_JOIN_SKEWED_PARTITION_THRESHOLD .key -> " 800" ,
729+ SQLConf .ADVISORY_PARTITION_SIZE_IN_BYTES .key -> " 800" ) {
730+ withTempView(" skewData1" , " skewData2" , " skewData3" , " skewData4" ) {
731+ spark
732+ .range(0 , 1000 , 1 , 10 )
733+ .select(
734+ when(' id < 250 , 249 )
735+ .when(' id >= 750 , 1000 )
736+ .otherwise(' id ).as(" key1" ),
737+ ' id as " value1" )
738+ .createOrReplaceTempView(" skewData1" )
739+ spark
740+ .range(0 , 1000 , 1 , 10 )
741+ .select(
742+ when(' id < 250 , 249 )
743+ .otherwise(' id ).as(" key2" ),
744+ ' id as " value2" )
745+ .createOrReplaceTempView(" skewData2" )
746+ spark
747+ .range(0 , 1000 , 1 , 10 )
748+ .select(
749+ when(' id < 250 , 249 )
750+ .when(' id >= 750 , 1000 )
751+ .otherwise(' id ).as(" key3" ),
752+ ' id as " value3" )
753+ .createOrReplaceTempView(" skewData3" )
754+ spark
755+ .range(0 , 1000 , 1 , 10 )
756+ .select(
757+ when(' id < 250 , 249 )
758+ .otherwise(' id ).as(" key4" ),
759+ ' id as " value4" )
760+ .createOrReplaceTempView(" skewData4" )
761+
762+ def checkSkewJoin (
763+ joins : Seq [SortMergeJoinExec ],
764+ leftSkewNum : Int ,
765+ rightSkewNum : Int ): Unit = {
766+ assert(joins.size == 1 && joins.head.isSkewJoin)
767+ assert(joins.head.left.collect {
768+ case r : CustomShuffleReaderExec => r
769+ }.head.partitionSpecs.collect {
770+ case p : PartialReducerPartitionSpec => p.reducerIndex
771+ }.distinct.length == leftSkewNum)
772+ assert(joins.head.right.collect {
773+ case r : CustomShuffleReaderExec => r
774+ }.head.partitionSpecs.collect {
775+ case p : PartialReducerPartitionSpec => p.reducerIndex
776+ }.distinct.length == rightSkewNum)
777+ }
778+
779+ // skewed inner join optimization with union (not union all)
780+ val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(
781+ " SELECT * FROM skewData1 join skewData2 ON key1 = key2 " +
782+ " UNION SELECT * FROM skewData3 join skewData4 ON key3 = key4" )
783+ innerAdaptivePlan transformUp {
784+ case u@ UnionExec (c) =>
785+ val innerSmj1 = findTopLevelSortMergeJoin(c.head)
786+ val innerSmj2 = findTopLevelSortMergeJoin(c.tail.head)
787+ checkSkewJoin(innerSmj1, 2 , 1 )
788+ checkSkewJoin(innerSmj2, 2 , 1 )
789+ u
790+ }
791+
792+ // skewed left outer join optimization with union all
793+ val (_, leftAdaptivePlan) = runAdaptiveAndVerifyResult(
794+ " SELECT * FROM skewData1 left outer join skewData2 ON key1 = key2 " +
795+ " UNION ALL SELECT * FROM skewData3 left outer join skewData4 ON key3 = key4" )
796+ leftAdaptivePlan transformUp {
797+ case u@ UnionExec (c) =>
798+ val leftSmj1 = findTopLevelSortMergeJoin(c.head)
799+ val leftSmj2 = findTopLevelSortMergeJoin(c.tail.head)
800+ checkSkewJoin(leftSmj1, 2 , 0 )
801+ checkSkewJoin(leftSmj2, 2 , 0 )
802+ u
803+ }
804+
805+ // skewed right outer join optimization with union all
806+ val (_, rightAdaptivePlan) = runAdaptiveAndVerifyResult(
807+ " SELECT * FROM skewData1 right outer join skewData2 ON key1 = key2 " +
808+ " UNION ALL SELECT * FROM skewData3 right outer join skewData4 ON key3 = key4" )
809+ rightAdaptivePlan transformUp {
810+ case u@ UnionExec (c) =>
811+ val rightSmj1 = findTopLevelSortMergeJoin(c.head)
812+ val rightSmj2 = findTopLevelSortMergeJoin(c.tail.head)
813+ checkSkewJoin(rightSmj1, 0 , 1 )
814+ checkSkewJoin(rightSmj2, 0 , 1 )
815+ u
816+ }
817+ }
818+ }
819+ }
820+
722821 test(" SPARK-30291: AQE should catch the exceptions when doing materialize" ) {
723822 withSQLConf(
724823 SQLConf .ADAPTIVE_EXECUTION_ENABLED .key -> " true" ) {
0 commit comments