Skip to content

Commit dd09e70

Browse files
committed
Add a new partitioning CoalescedHashPartitioning
1 parent 80bef0d commit dd09e70

File tree

3 files changed

+54
-3
lines changed

3 files changed

+54
-3
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,3 +340,28 @@ case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning {
340340
case _ => false
341341
}
342342
}
343+
344+
/**
 * With adaptive execution (AE), multiple partitions of a hash-partitioned shuffle
 * output can be coalesced into a single partition. `CoalescedHashPartitioning`
 * represents that post-coalescing output partitioning: the data is still clustered
 * by `expressions`, but `numPartitions` reflects the coalesced partition count
 * rather than the original shuffle partition count.
 *
 * NOTE(review): mixes in `Expression` with `Unevaluable` — presumably to mirror
 * `HashPartitioning`'s design so it can live inside expression trees; confirm
 * against `HashPartitioning` in this file.
 *
 * @param expressions   the clustering expressions the data was hash-partitioned on
 * @param numPartitions the number of partitions after coalescing
 */
case class CoalescedHashPartitioning(
    expressions: Seq[Expression],
    numPartitions: Int)
  extends Expression with Partitioning with Unevaluable {

  // As an (unevaluable) expression, its children are the clustering expressions.
  override def children: Seq[Expression] = expressions
  override def nullable: Boolean = false
  override def dataType: DataType = IntegerType

  /**
   * Satisfies a [[ClusteredDistribution]] when every clustering expression here is
   * semantically present in the required clustering, and the required partition
   * count (if any) matches the coalesced count.
   */
  override def satisfies0(required: Distribution): Boolean = {
    super.satisfies0(required) || {
      required match {
        case ClusteredDistribution(requiredClustering, requiredNumPartitions) =>
          // Subset check uses semanticEquals so cosmetic expression differences
          // (e.g. qualifiers) do not break the match.
          expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) &&
            // forall on Option: vacuously true when no partition count is required.
            requiredNumPartitions.forall(_ == numPartitions)
        case _ => false
      }
    }
  }
}

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import scala.collection.mutable.ArrayBuffer
2323
import org.apache.spark.rdd.RDD
2424
import org.apache.spark.sql.catalyst.InternalRow
2525
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
26-
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
26+
import org.apache.spark.sql.catalyst.plans.physical.{CoalescedHashPartitioning, HashPartitioning, Partitioning, UnknownPartitioning}
2727
import org.apache.spark.sql.execution._
2828
import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec}
2929
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -65,6 +65,14 @@ case class CustomShuffleReaderExec private(
6565
case _ =>
6666
throw new IllegalStateException("operating on canonicalization plan")
6767
}
68+
} else if (partitionSpecs.nonEmpty &&
69+
partitionSpecs.forall(_.isInstanceOf[CoalescedPartitionSpec])) {
70+
child match {
71+
case ShuffleQueryStageExec(_, ShuffleExchangeExec(p: HashPartitioning, _, _)) =>
72+
CoalescedHashPartitioning(p.expressions, partitionSpecs.size)
73+
case _ =>
74+
throw new IllegalStateException("operating on canonicalization plan")
75+
}
6876
} else {
6977
UnknownPartitioning(partitionSpecs.length)
7078
}

sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -738,6 +738,24 @@ class AdaptiveQueryExecSuite
738738
SQLConf.SHUFFLE_PARTITIONS.key -> "100",
739739
SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "800",
740740
SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "800") {
741+
742+
// SMJ
743+
// Sort
744+
// CustomShuffleReader(coalesced)
745+
// Shuffle
746+
// Sort
747+
// HashAggregate
748+
// CustomShuffleReader(coalesced)
749+
// Shuffle
750+
// -->
751+
// SMJ
752+
// Sort
753+
// CustomShuffleReader(coalesced and skew)
754+
// Shuffle
755+
// Sort
756+
// HashAggregate
757+
// CustomShuffleReader(coalesced)
758+
// Shuffle
741759
withTempView("skewData1", "skewData2") {
742760
spark
743761
.range(0, 1000, 1, 10)
@@ -747,14 +765,14 @@ class AdaptiveQueryExecSuite
747765
.otherwise('id).as("key1"),
748766
'id as "value1")
749767
.createOrReplaceTempView("skewData1")
750-
751768
spark
752769
.range(0, 1000, 1, 10)
753770
.select(
754771
when('id < 250, 249)
755772
.otherwise('id).as("key2"),
756773
'id as "value2")
757774
.createOrReplaceTempView("skewData2")
775+
758776
val sqlText =
759777
"""
760778
|SELECT * FROM
@@ -764,7 +782,7 @@ class AdaptiveQueryExecSuite
764782
| SELECT skewData2.key2, sum(skewData2.value2) AS sum2
765783
| FROM skewData2 GROUP BY skewData2.key2
766784
| ) AS data2
767-
|ON data1.key1 = data2.key2
785+
|ON data1.key1 = data2.key2 LIMIT 10
768786
|""".stripMargin
769787

770788
val (_, adaptivePlan) = runAdaptiveAndVerifyResult(sqlText)

0 commit comments

Comments
 (0)