Skip to content

Commit 5d78537

Browse files
manuzhang authored and cloud-fan committed
[SPARK-31942] Revert "[SPARK-31864][SQL] Adjust AQE skew join trigger condition"
### What changes were proposed in this pull request?

This reverts commit b9737c3 while keeping the following changes:

* set default value of `spark.sql.adaptive.skewJoin.skewedPartitionFactor` to 5
* improve tests
* remove unused imports

### Why are the changes needed?

As discussed in #28669 (comment), revert SPARK-31864 so that the skew join optimization works for extremely clustered keys.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

Closes #28770 from manuzhang/spark-31942.

Authored-by: manuzhang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 22dda6e commit 5d78537

File tree

1 file changed

+16
-13
lines changed

1 file changed

+16
-13
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala

Lines changed: 16 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@ import scala.collection.mutable
2121

2222
import org.apache.commons.io.FileUtils
2323

24-
import org.apache.spark.{MapOutputTrackerMaster, SparkEnv}
24+
import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv}
2525
import org.apache.spark.sql.catalyst.plans._
2626
import org.apache.spark.sql.catalyst.rules.Rule
2727
import org.apache.spark.sql.execution._
@@ -70,9 +70,9 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
7070
size > conf.getConf(SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD)
7171
}
7272

73-
private def medianSize(sizes: Seq[Long]): Long = {
74-
val numPartitions = sizes.length
75-
val bytes = sizes.sorted
73+
private def medianSize(stats: MapOutputStatistics): Long = {
74+
val numPartitions = stats.bytesByPartitionId.length
75+
val bytes = stats.bytesByPartitionId.sorted
7676
numPartitions match {
7777
case _ if (numPartitions % 2 == 0) =>
7878
math.max((bytes(numPartitions / 2) + bytes(numPartitions / 2 - 1)) / 2, 1)
@@ -163,16 +163,16 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
163163
if supportedJoinTypes.contains(joinType) =>
164164
assert(left.partitionsWithSizes.length == right.partitionsWithSizes.length)
165165
val numPartitions = left.partitionsWithSizes.length
166-
// Use the median size of the actual (coalesced) partition sizes to detect skewed partitions.
167-
val leftMedSize = medianSize(left.partitionsWithSizes.map(_._2))
168-
val rightMedSize = medianSize(right.partitionsWithSizes.map(_._2))
166+
// We use the median size of the original shuffle partitions to detect skewed partitions.
167+
val leftMedSize = medianSize(left.mapStats)
168+
val rightMedSize = medianSize(right.mapStats)
169169
logDebug(
170170
s"""
171171
|Optimizing skewed join.
172172
|Left side partitions size info:
173-
|${getSizeInfo(leftMedSize, left.partitionsWithSizes.map(_._2))}
173+
|${getSizeInfo(leftMedSize, left.mapStats.bytesByPartitionId)}
174174
|Right side partitions size info:
175-
|${getSizeInfo(rightMedSize, right.partitionsWithSizes.map(_._2))}
175+
|${getSizeInfo(rightMedSize, right.mapStats.bytesByPartitionId)}
176176
""".stripMargin)
177177
val canSplitLeft = canSplitLeftSide(joinType)
178178
val canSplitRight = canSplitRightSide(joinType)
@@ -291,15 +291,17 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
291291
private object ShuffleStage {
292292
def unapply(plan: SparkPlan): Option[ShuffleStageInfo] = plan match {
293293
case s: ShuffleQueryStageExec if s.mapStats.isDefined =>
294-
val sizes = s.mapStats.get.bytesByPartitionId
294+
val mapStats = s.mapStats.get
295+
val sizes = mapStats.bytesByPartitionId
295296
val partitions = sizes.zipWithIndex.map {
296297
case (size, i) => CoalescedPartitionSpec(i, i + 1) -> size
297298
}
298-
Some(ShuffleStageInfo(s, partitions))
299+
Some(ShuffleStageInfo(s, mapStats, partitions))
299300

300301
case CustomShuffleReaderExec(s: ShuffleQueryStageExec, partitionSpecs)
301302
if s.mapStats.isDefined && partitionSpecs.nonEmpty =>
302-
val sizes = s.mapStats.get.bytesByPartitionId
303+
val mapStats = s.mapStats.get
304+
val sizes = mapStats.bytesByPartitionId
303305
val partitions = partitionSpecs.map {
304306
case spec @ CoalescedPartitionSpec(start, end) =>
305307
var sum = 0L
@@ -312,12 +314,13 @@ private object ShuffleStage {
312314
case other => throw new IllegalArgumentException(
313315
s"Expect CoalescedPartitionSpec but got $other")
314316
}
315-
Some(ShuffleStageInfo(s, partitions))
317+
Some(ShuffleStageInfo(s, mapStats, partitions))
316318

317319
case _ => None
318320
}
319321
}
320322

321323
private case class ShuffleStageInfo(
322324
shuffleStage: ShuffleQueryStageExec,
325+
mapStats: MapOutputStatistics,
323326
partitionsWithSizes: Seq[(CoalescedPartitionSpec, Long)])

0 commit comments

Comments
 (0)