diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
index 3620f27058af2..7a70a0c245bc7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
@@ -142,6 +142,14 @@ object OptimizeLocalShuffleReader {
   def canUseLocalShuffleReader(plan: SparkPlan): Boolean = plan match {
     case s: ShuffleQueryStageExec =>
       s.shuffle.canChangeNumPartitions
+    // This CustomShuffleReaderExec is used on the skew side; its numPartitions has increased.
+    case CustomShuffleReaderExec(_, partitionSpecs)
+      if partitionSpecs.exists(_.isInstanceOf[PartialReducerPartitionSpec]) => false
+    // This CustomShuffleReaderExec is used on the non-skew side; its numPartitions equals
+    // that of the skew-side CustomShuffleReaderExec.
+    case CustomShuffleReaderExec(_, partitionSpecs) if partitionSpecs.size > 1 &&
+      partitionSpecs.forall(_.isInstanceOf[CoalescedPartitionSpec]) &&
+      partitionSpecs.toSet.size != partitionSpecs.size => false
     case CustomShuffleReaderExec(s: ShuffleQueryStageExec, partitionSpecs) =>
       s.shuffle.canChangeNumPartitions && partitionSpecs.nonEmpty
     case _ => false
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
index 627f0600f2383..7820df9c9a2bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
@@ -23,6 +23,7 @@ import org.apache.commons.io.FileUtils
 
 import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv}
 import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}
@@ -130,20 +131,45 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
     }
   }
 
-  private def canSplitLeftSide(joinType: JoinType) = {
-    joinType == Inner || joinType == Cross || joinType == LeftSemi ||
-      joinType == LeftAnti || joinType == LeftOuter
+  private def canSplitLeftSide(joinType: JoinType, plan: SparkPlan) = {
+    (joinType == Inner || joinType == Cross || joinType == LeftSemi ||
+      joinType == LeftAnti || joinType == LeftOuter) && allUnspecifiedDistribution(plan)
   }
 
-  private def canSplitRightSide(joinType: JoinType) = {
-    joinType == Inner || joinType == Cross || joinType == RightOuter
+  private def canSplitRightSide(joinType: JoinType, plan: SparkPlan) = {
+    (joinType == Inner || joinType == Cross ||
+      joinType == RightOuter) && allUnspecifiedDistribution(plan)
   }
 
+  // Check whether any node in the tree requires a child distribution other than
+  // UnspecifiedDistribution.
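+  // For example (see CASE 1 in the SPARK-32201 test below), a HashAggregateExec between the
+  // join and the shuffle stage requires ClusteredDistribution; replaying a single reducer
+  // partition in several tasks would violate that requirement, so such a side is not split.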
+  private def allUnspecifiedDistribution(plan: SparkPlan): Boolean = plan.find { p =>
+    p.requiredChildDistribution.exists {
+      case UnspecifiedDistribution => false
+      case _ => true
+    }
+  }.isEmpty
+
   private def getSizeInfo(medianSize: Long, sizes: Seq[Long]): String = {
     s"median size: $medianSize, max size: ${sizes.max}, min size: ${sizes.min}, avg size: " +
       sizes.sum / sizes.length
   }
 
+  // Find the first shuffle stage under the given plan node, if any.
+  private def findShuffleStage(plan: SparkPlan): Option[ShuffleStageInfo] = {
+    plan collectFirst {
+      case ShuffleStage(shuffleStageInfo) => shuffleStageInfo
+    }
+  }
+
+  // Replace the shuffle reader that reads the same shuffle output as `newCtm` with `newCtm`.
+  private def replaceSkewedShuffleReader(
+      smj: SparkPlan, newCtm: CustomShuffleReaderExec): SparkPlan = {
+    smj transformUp {
+      case CustomShuffleReaderExec(child, _) if child.sameResult(newCtm.child) => newCtm
+    }
+  }
+
   /*
    * This method aim to optimize the skewed join with the following steps:
    * 1. Check whether the shuffle partition is skewed based on the median size
@@ -157,96 +183,106 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
    *    3 tasks separately.
    */
   def optimizeSkewJoin(plan: SparkPlan): SparkPlan = plan.transformUp {
-    case smj @ SortMergeJoinExec(_, _, joinType, _,
-        s1 @ SortExec(_, _, ShuffleStage(left: ShuffleStageInfo), _),
-        s2 @ SortExec(_, _, ShuffleStage(right: ShuffleStageInfo), _), _)
+    case smj @ SortMergeJoinExec(_, _, joinType, _, s1: SortExec, s2: SortExec, _)
         if supportedJoinTypes.contains(joinType) =>
-      assert(left.partitionsWithSizes.length == right.partitionsWithSizes.length)
-      val numPartitions = left.partitionsWithSizes.length
-      // We use the median size of the original shuffle partitions to detect skewed partitions.
-      val leftMedSize = medianSize(left.mapStats)
-      val rightMedSize = medianSize(right.mapStats)
-      logDebug(
-        s"""
-          |Optimizing skewed join.
-          |Left side partitions size info:
-          |${getSizeInfo(leftMedSize, left.mapStats.bytesByPartitionId)}
-          |Right side partitions size info:
-          |${getSizeInfo(rightMedSize, right.mapStats.bytesByPartitionId)}
-        """.stripMargin)
-      val canSplitLeft = canSplitLeftSide(joinType)
-      val canSplitRight = canSplitRightSide(joinType)
-      // We use the actual partition sizes (may be coalesced) to calculate target size, so that
-      // the final data distribution is even (coalesced partitions + split partitions).
-      val leftActualSizes = left.partitionsWithSizes.map(_._2)
-      val rightActualSizes = right.partitionsWithSizes.map(_._2)
-      val leftTargetSize = targetSize(leftActualSizes, leftMedSize)
-      val rightTargetSize = targetSize(rightActualSizes, rightMedSize)
-
-      val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
-      val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
-      var numSkewedLeft = 0
-      var numSkewedRight = 0
-      for (partitionIndex <- 0 until numPartitions) {
-        val leftActualSize = leftActualSizes(partitionIndex)
-        val isLeftSkew = isSkewed(leftActualSize, leftMedSize) && canSplitLeft
-        val leftPartSpec = left.partitionsWithSizes(partitionIndex)._1
-        val isLeftCoalesced = leftPartSpec.startReducerIndex + 1 < leftPartSpec.endReducerIndex
-
-        val rightActualSize = rightActualSizes(partitionIndex)
-        val isRightSkew = isSkewed(rightActualSize, rightMedSize) && canSplitRight
-        val rightPartSpec = right.partitionsWithSizes(partitionIndex)._1
-        val isRightCoalesced = rightPartSpec.startReducerIndex + 1 < rightPartSpec.endReducerIndex
-
-        // A skewed partition should never be coalesced, but skip it here just to be safe.
-        val leftParts = if (isLeftSkew && !isLeftCoalesced) {
-          val reducerId = leftPartSpec.startReducerIndex
-          val skewSpecs = createSkewPartitionSpecs(
-            left.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, leftTargetSize)
-          if (skewSpecs.isDefined) {
-            logDebug(s"Left side partition $partitionIndex " +
-              s"(${FileUtils.byteCountToDisplaySize(leftActualSize)}) is skewed, " +
-              s"split it into ${skewSpecs.get.length} parts.")
-            numSkewedLeft += 1
+      // Find the shuffle stages under the two sort nodes of the join.
+      val leftOpt = findShuffleStage(s1)
+      val rightOpt = findShuffleStage(s2)
+      if (leftOpt.isEmpty || rightOpt.isEmpty) {
+        smj
+      } else {
+        val left = leftOpt.get
+        val right = rightOpt.get
+        assert(left.partitionsWithSizes.length == right.partitionsWithSizes.length)
+        val numPartitions = left.partitionsWithSizes.length
+        // We use the median size of the original shuffle partitions to detect skewed partitions.
+        val leftMedSize = medianSize(left.mapStats)
+        val rightMedSize = medianSize(right.mapStats)
+        logDebug(
+          s"""
+            |Optimizing skewed join.
+            |Left side partitions size info:
+            |${getSizeInfo(leftMedSize, left.mapStats.bytesByPartitionId)}
+            |Right side partitions size info:
+            |${getSizeInfo(rightMedSize, right.mapStats.bytesByPartitionId)}
+          """.stripMargin)
+        val canSplitLeft = canSplitLeftSide(joinType, s1)
+        val canSplitRight = canSplitRightSide(joinType, s2)
+        // We use the actual partition sizes (may be coalesced) to calculate target size, so that
+        // the final data distribution is even (coalesced partitions + split partitions).
+        val leftActualSizes = left.partitionsWithSizes.map(_._2)
+        val rightActualSizes = right.partitionsWithSizes.map(_._2)
+        val leftTargetSize = targetSize(leftActualSizes, leftMedSize)
+        val rightTargetSize = targetSize(rightActualSizes, rightMedSize)
+
+        val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
+        val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
+        var numSkewedLeft = 0
+        var numSkewedRight = 0
+        for (partitionIndex <- 0 until numPartitions) {
+          val leftActualSize = leftActualSizes(partitionIndex)
+          val isLeftSkew = isSkewed(leftActualSize, leftMedSize) && canSplitLeft
+          val leftPartSpec = left.partitionsWithSizes(partitionIndex)._1
+          val isLeftCoalesced = leftPartSpec.startReducerIndex + 1 < leftPartSpec.endReducerIndex
+
+          val rightActualSize = rightActualSizes(partitionIndex)
+          val isRightSkew = isSkewed(rightActualSize, rightMedSize) && canSplitRight
+          val rightPartSpec = right.partitionsWithSizes(partitionIndex)._1
+          val isRightCoalesced =
+            rightPartSpec.startReducerIndex + 1 < rightPartSpec.endReducerIndex
+
+          // A skewed partition should never be coalesced, but skip it here just to be safe.
+          val leftParts = if (isLeftSkew && !isLeftCoalesced) {
+            val reducerId = leftPartSpec.startReducerIndex
+            val skewSpecs = createSkewPartitionSpecs(
+              left.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, leftTargetSize)
+            if (skewSpecs.isDefined) {
+              logDebug(s"Left side partition $partitionIndex " +
+                s"(${FileUtils.byteCountToDisplaySize(leftActualSize)}) is skewed, " +
+                s"split it into ${skewSpecs.get.length} parts.")
+              numSkewedLeft += 1
+            }
+            skewSpecs.getOrElse(Seq(leftPartSpec))
+          } else {
+            Seq(leftPartSpec)
           }
-          skewSpecs.getOrElse(Seq(leftPartSpec))
-        } else {
-          Seq(leftPartSpec)
-        }
-        // A skewed partition should never be coalesced, but skip it here just to be safe.
-        val rightParts = if (isRightSkew && !isRightCoalesced) {
-          val reducerId = rightPartSpec.startReducerIndex
-          val skewSpecs = createSkewPartitionSpecs(
-            right.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, rightTargetSize)
-          if (skewSpecs.isDefined) {
-            logDebug(s"Right side partition $partitionIndex " +
-              s"(${FileUtils.byteCountToDisplaySize(rightActualSize)}) is skewed, " +
-              s"split it into ${skewSpecs.get.length} parts.")
-            numSkewedRight += 1
+          // A skewed partition should never be coalesced, but skip it here just to be safe.
+          val rightParts = if (isRightSkew && !isRightCoalesced) {
+            val reducerId = rightPartSpec.startReducerIndex
+            val skewSpecs = createSkewPartitionSpecs(
+              right.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, rightTargetSize)
+            if (skewSpecs.isDefined) {
+              logDebug(s"Right side partition $partitionIndex " +
+                s"(${FileUtils.byteCountToDisplaySize(rightActualSize)}) is skewed, " +
+                s"split it into ${skewSpecs.get.length} parts.")
+              numSkewedRight += 1
+            }
+            skewSpecs.getOrElse(Seq(rightPartSpec))
+          } else {
+            Seq(rightPartSpec)
           }
-          skewSpecs.getOrElse(Seq(rightPartSpec))
-        } else {
-          Seq(rightPartSpec)
-        }
-        for {
-          leftSidePartition <- leftParts
-          rightSidePartition <- rightParts
-        } {
-          leftSidePartitions += leftSidePartition
-          rightSidePartitions += rightSidePartition
+          for {
+            leftSidePartition <- leftParts
+            rightSidePartition <- rightParts
+          } {
+            leftSidePartitions += leftSidePartition
+            rightSidePartitions += rightSidePartition
+          }
         }
-      logDebug(s"number of skewed partitions: left $numSkewedLeft, right $numSkewedRight")
-      if (numSkewedLeft > 0 || numSkewedRight > 0) {
-        val newLeft = CustomShuffleReaderExec(left.shuffleStage, leftSidePartitions.toSeq)
-        val newRight = CustomShuffleReaderExec(right.shuffleStage, rightSidePartitions.toSeq)
-        smj.copy(
-          left = s1.copy(child = newLeft), right = s2.copy(child = newRight), isSkewJoin = true)
-      } else {
-        smj
+        logDebug(s"number of skewed partitions: left $numSkewedLeft, right $numSkewedRight")
+        if (numSkewedLeft > 0 || numSkewedRight > 0) {
+          val newLeft = CustomShuffleReaderExec(left.shuffleStage, leftSidePartitions.toSeq)
+          val newRight = CustomShuffleReaderExec(right.shuffleStage, rightSidePartitions.toSeq)
+          // Put the new skew readers back under the join, keeping any nodes in between.
+          val newSmj = replaceSkewedShuffleReader(
+            replaceSkewedShuffleReader(smj, newLeft), newRight).asInstanceOf[SortMergeJoinExec]
+          newSmj.copy(isSkewJoin = true)
+        } else {
+          smj
+        }
       }
   }
 
@@ -263,18 +299,31 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
 
     val shuffleStages = collectShuffleStages(plan)
 
     if (shuffleStages.length == 2) {
-      // When multi table join, there will be too many complex combination to consider.
-      // Currently we only handle 2 table join like following use case.
+      // SPARK-32201. Skew join supports the pattern below, where ".." may contain any
+      // number of nodes, e.g. a BroadcastHashJoinExec, so joins of more than two tables
+      // can be handled as well.
       // SMJ
       //   Sort
-      //     Shuffle
+      //     ..
+      //       Shuffle
       //   Sort
-      //     Shuffle
+      //     ..
+      //       Shuffle
       val optimizePlan = optimizeSkewJoin(plan)
-      val numShuffles = ensureRequirements.apply(optimizePlan).collect {
-        case e: ShuffleExchangeExec => e
-      }.length
+      def countAdditionalShuffleInAncestorsOfSkewJoin(optimizePlan: SparkPlan): Int = {
+        val newPlan = ensureRequirements.apply(optimizePlan)
+        val totalAdditionalShuffles = newPlan.collect { case e: ShuffleExchangeExec => e }.size
+        val numShufflesFromDescendants =
+          newPlan.collectFirst { case j: SortMergeJoinExec if j.isSkewJoin => j }.map { smj =>
+            smj.collect { case e: ShuffleExchangeExec => e }.size
+          }.getOrElse(0)
+        totalAdditionalShuffles - numShufflesFromDescendants
+      }
+
+      // Check whether we introduced new shuffles above the skewed join operator. New shuffles
+      // introduced below the join operator do not matter, since they will not actually be
+      // executed under the current adaptive execution framework.
+      val numShuffles = countAdditionalShuffleInAncestorsOfSkewJoin(optimizePlan)
       if (numShuffles > 0) {
         logDebug("OptimizeSkewedJoin rule is not applied due" +
           " to additional shuffles will be introduced.")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index b9f6684447dd8..c848f7ea62b36 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -82,6 +82,14 @@ case class SortMergeJoinExec(
     }
   }
 
+  // A skew join replays some reducer partitions in multiple tasks, so rows with the same
+  // join key may end up in more than one output partition and the children's hash
+  // partitioning can no longer be propagated.
+  override def outputPartitioning: Partitioning = {
+    if (isSkewJoin) {
+      UnknownPartitioning(0)
+    } else {
+      super.outputPartitioning
+    }
+  }
+
   override def outputOrdering: Seq[SortOrder] = joinType match {
     // For inner join, orders of both sides keys should be kept.
     case _: InnerLike =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 03471fb047260..0ae35278643fd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -693,23 +693,6 @@ class AdaptiveQueryExecSuite
             'id as "value2")
           .createOrReplaceTempView("skewData2")
 
-        def checkSkewJoin(
-            joins: Seq[SortMergeJoinExec],
-            leftSkewNum: Int,
-            rightSkewNum: Int): Unit = {
-          assert(joins.size == 1 && joins.head.isSkewJoin)
-          assert(joins.head.left.collect {
-            case r: CustomShuffleReaderExec => r
-          }.head.partitionSpecs.collect {
-            case p: PartialReducerPartitionSpec => p.reducerIndex
-          }.distinct.length == leftSkewNum)
-          assert(joins.head.right.collect {
-            case r: CustomShuffleReaderExec => r
-          }.head.partitionSpecs.collect {
-            case p: PartialReducerPartitionSpec => p.reducerIndex
-          }.distinct.length == rightSkewNum)
-        }
-
         // skewed inner join optimization
         val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(
           "SELECT * FROM skewData1 join skewData2 ON key1 = key2")
@@ -731,6 +714,130 @@ class AdaptiveQueryExecSuite
     }
   }
 
+  private def checkSkewJoin(
+      joins: Seq[SortMergeJoinExec],
+      leftSkewNum: Int,
+      rightSkewNum: Int): Unit = {
+    assert(joins.size == 1 && joins.head.isSkewJoin)
+    assert(joins.head.left.collect {
+      case r: CustomShuffleReaderExec => r
+    }.head.partitionSpecs.collect {
+      case p: PartialReducerPartitionSpec => p.reducerIndex
+    }.distinct.length == leftSkewNum)
+    assert(joins.head.right.collect {
+      case r: CustomShuffleReaderExec => r
+    }.head.partitionSpecs.collect {
+      case p: PartialReducerPartitionSpec => p.reducerIndex
+    }.distinct.length == rightSkewNum)
+  }
+
+  test("SPARK-32201: handle general skew join pattern") {
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",
+      SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
+      SQLConf.SHUFFLE_PARTITIONS.key -> "100",
+      SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "800",
+      SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "800") {
+
+      // CASE 1:
+      // SMJ
+      //   Sort
+      //     CustomShuffleReader(coalesced)
+      //       Shuffle
+      //   Sort
+      //     HashAggregate
+      //       CustomShuffleReader(coalesced)
+      //         Shuffle
+      // -->
+      // SMJ
+      //   Sort
+      //     CustomShuffleReader(coalesced and skew)
+      //       Shuffle
+      //   Sort
+      //     HashAggregate
+      //       CustomShuffleReader(coalesced)
+      //         Shuffle
+      withTempView("skewData1", "skewData2", "smallData") {
+        spark
+          .range(0, 1000, 1, 10)
+          .select(
+            when('id < 250, 249)
+              .when('id >= 750, 1000)
+              .otherwise('id).as("key1"),
+            'id as "value1")
+          .createOrReplaceTempView("skewData1")
+
+        spark
+          .range(0, 1000, 1, 10)
+          .select(
+            when('id < 250, 249)
+              .otherwise('id).as("key2"),
+            'id as "value2")
+          .createOrReplaceTempView("skewData2")
+
+        val sqlText =
+          """
+            |SELECT * FROM
+            |  skewData1 AS data1
+            |  INNER JOIN
+            |  (
+            |    SELECT skewData2.key2, sum(skewData2.value2) AS sum2
+            |    FROM skewData2 GROUP BY skewData2.key2
+            |  ) AS data2
+            |ON data1.key1 = data2.key2 LIMIT 10
+            |""".stripMargin
+
+        val (_, adaptivePlan) = runAdaptiveAndVerifyResult(sqlText)
+        val innerSmj = findTopLevelSortMergeJoin(adaptivePlan)
+        checkSkewJoin(innerSmj, 2, 0)
+
+        // CASE 2:
+        // SMJ
+        //   Sort
+        //     SMJ
+        //       CustomShuffleReader(coalesced)
+        //         Shuffle
+        //   Sort
+        //     CustomShuffleReader(coalesced)
+        //       Shuffle
+        // -->
+        // SMJ
+        //   Sort
+        //     BroadcastHashJoin  <-- the inner SMJ is changed to a BroadcastHashJoin
+        //       CustomShuffleReader(coalesced)
+        //         Shuffle
+        //   Sort
+        //     CustomShuffleReader(coalesced and skew)
+        //       Shuffle
+        spark
+          .range(0, 100, 1, 10)
+          .select(
+            when('id < 250, 249)
+              .otherwise('id).as("key3"),
+            expr("concat(id, 'aaa')") as "value3")
+          .createOrReplaceTempView("smallData")
+
+        val sqlText2 =
+          """
+            |SELECT * FROM
+            |  (
+            |    SELECT t1.*
+            |    FROM skewData1 t1 LEFT JOIN smallData t2
+            |    ON t1.key1 = t2.key3
+            |    AND t2.value3 = 'xyz'
+            |  ) AS data1
+            |  INNER JOIN
+            |  skewData2 AS data2
+            |ON data1.key1 = data2.key2 LIMIT 10
+            |""".stripMargin
+        val (_, adaptivePlan2) = runAdaptiveAndVerifyResult(sqlText2)
+        val innerSmj2 = findTopLevelSortMergeJoin(adaptivePlan2)
+        checkSkewJoin(innerSmj2, 0, 1)
+      }
+    }
+  }
+
   test("SPARK-30291: AQE should catch the exceptions when doing materialize") {
     withSQLConf(
       SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
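A note on the splitting logic in optimizeSkewJoin above: the per-partition pairing it performs
can be illustrated with a minimal, self-contained sketch. This is a simplified model, not the
rule itself: SkewJoinSketch, Whole, Partial, split and pair are hypothetical stand-ins for
ShufflePartitionSpec, CoalescedPartitionSpec, PartialReducerPartitionSpec,
createSkewPartitionSpecs and the buffer-filling loop, and it splits by total size rather than
by map-output ranges.

object SkewJoinSketch {
  sealed trait Spec
  case class Whole(reducer: Int) extends Spec
  case class Partial(reducer: Int, chunk: Int) extends Spec

  // Split a skewed partition into ceil(size / target) chunks; a non-skewed partition
  // stays whole.
  def split(reducer: Int, size: Long, target: Long): Seq[Spec] = {
    val n = math.max(1, math.ceil(size.toDouble / target).toInt)
    if (n == 1) Seq(Whole(reducer)) else (0 until n).map(Partial(reducer, _))
  }

  // For each reducer partition, pair every left chunk with every right chunk so that the
  // two sides' spec sequences stay aligned index by index, mirroring the for-comprehension
  // that fills leftSidePartitions/rightSidePartitions in the rule.
  def pair(left: Seq[Long], right: Seq[Long], target: Long): (Seq[Spec], Seq[Spec]) = {
    val l = Seq.newBuilder[Spec]
    val r = Seq.newBuilder[Spec]
    for (i <- left.indices) {
      for (a <- split(i, left(i), target); b <- split(i, right(i), target)) {
        l += a
        r += b
      }
    }
    (l.result(), r.result())
  }

  def main(args: Array[String]): Unit = {
    // Partition 1 is skewed on the left (3000 bytes vs a 1000-byte target): it is split
    // into 3 chunks, and the matching right partition is replayed once per chunk.
    val (l, r) = pair(Seq(100L, 3000L, 80L), Seq(120L, 90L, 110L), target = 1000L)
    l.zip(r).foreach { case (a, b) => println(s"$a <-> $b") }
  }
}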