@@ -21,7 +21,7 @@ import scala.collection.mutable
2121
2222import org .apache .commons .io .FileUtils
2323
24- import org .apache .spark .{MapOutputTrackerMaster , SparkEnv }
24+ import org .apache .spark .{MapOutputStatistics , MapOutputTrackerMaster , SparkEnv }
2525import org .apache .spark .sql .catalyst .plans ._
2626import org .apache .spark .sql .catalyst .rules .Rule
2727import org .apache .spark .sql .execution ._
@@ -70,9 +70,9 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
7070 size > conf.getConf(SQLConf .SKEW_JOIN_SKEWED_PARTITION_THRESHOLD )
7171 }
7272
73- private def medianSize (sizes : Seq [ Long ] ): Long = {
74- val numPartitions = sizes .length
75- val bytes = sizes .sorted
73+ private def medianSize (stats : MapOutputStatistics ): Long = {
74+ val numPartitions = stats.bytesByPartitionId .length
75+ val bytes = stats.bytesByPartitionId .sorted
7676 numPartitions match {
7777 case _ if (numPartitions % 2 == 0 ) =>
7878 math.max((bytes(numPartitions / 2 ) + bytes(numPartitions / 2 - 1 )) / 2 , 1 )
@@ -163,16 +163,16 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
163163 if supportedJoinTypes.contains(joinType) =>
164164 assert(left.partitionsWithSizes.length == right.partitionsWithSizes.length)
165165 val numPartitions = left.partitionsWithSizes.length
166- // Use the median size of the actual (coalesced) partition sizes to detect skewed partitions.
167- val leftMedSize = medianSize(left.partitionsWithSizes.map(_._2) )
168- val rightMedSize = medianSize(right.partitionsWithSizes.map(_._2) )
166+ // We use the median size of the original shuffle partitions to detect skewed partitions.
167+ val leftMedSize = medianSize(left.mapStats )
168+ val rightMedSize = medianSize(right.mapStats )
169169 logDebug(
170170 s """
171171 |Optimizing skewed join.
172172 |Left side partitions size info:
173- | ${getSizeInfo(leftMedSize, left.partitionsWithSizes.map(_._2) )}
173+ | ${getSizeInfo(leftMedSize, left.mapStats.bytesByPartitionId )}
174174 |Right side partitions size info:
175- | ${getSizeInfo(rightMedSize, right.partitionsWithSizes.map(_._2) )}
175+ | ${getSizeInfo(rightMedSize, right.mapStats.bytesByPartitionId )}
176176 """ .stripMargin)
177177 val canSplitLeft = canSplitLeftSide(joinType)
178178 val canSplitRight = canSplitRightSide(joinType)
@@ -291,15 +291,17 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
291291private object ShuffleStage {
292292 def unapply (plan : SparkPlan ): Option [ShuffleStageInfo ] = plan match {
293293 case s : ShuffleQueryStageExec if s.mapStats.isDefined =>
294- val sizes = s.mapStats.get.bytesByPartitionId
294+ val mapStats = s.mapStats.get
295+ val sizes = mapStats.bytesByPartitionId
295296 val partitions = sizes.zipWithIndex.map {
296297 case (size, i) => CoalescedPartitionSpec (i, i + 1 ) -> size
297298 }
298- Some (ShuffleStageInfo (s, partitions))
299+ Some (ShuffleStageInfo (s, mapStats, partitions))
299300
300301 case CustomShuffleReaderExec (s : ShuffleQueryStageExec , partitionSpecs)
301302 if s.mapStats.isDefined && partitionSpecs.nonEmpty =>
302- val sizes = s.mapStats.get.bytesByPartitionId
303+ val mapStats = s.mapStats.get
304+ val sizes = mapStats.bytesByPartitionId
303305 val partitions = partitionSpecs.map {
304306 case spec @ CoalescedPartitionSpec (start, end) =>
305307 var sum = 0L
@@ -312,12 +314,13 @@ private object ShuffleStage {
312314 case other => throw new IllegalArgumentException (
313315 s " Expect CoalescedPartitionSpec but got $other" )
314316 }
315- Some (ShuffleStageInfo (s, partitions))
317+ Some (ShuffleStageInfo (s, mapStats, partitions))
316318
317319 case _ => None
318320 }
319321}
320322
/**
 * Describes the shuffle stage feeding one side of a join, as produced by `ShuffleStage.unapply`.
 *
 * @param shuffleStage the matched shuffle query stage.
 * @param mapStats the stage's map output statistics, taken from `shuffleStage.mapStats.get` when
 *                 the stage is matched. Its `bytesByPartitionId` holds the sizes of the original
 *                 shuffle partitions, which is what the skew detection computes its median from.
 * @param partitionsWithSizes each partition spec paired with its size. For a bare stage this is
 *                            one `CoalescedPartitionSpec(i, i + 1)` per entry of
 *                            `bytesByPartitionId`; under a `CustomShuffleReaderExec` it follows
 *                            the reader's partition specs (the size is presumably summed over
 *                            each spec's range; the summing loop is not visible here, so
 *                            confirm against the full file).
 */
321323private case class ShuffleStageInfo (
322324 shuffleStage : ShuffleQueryStageExec ,
325+ mapStats : MapOutputStatistics ,
323326 partitionsWithSizes : Seq [(CoalescedPartitionSpec , Long )])
0 commit comments