
Commit 11d1f34

fix a bug where a union node whose children have differing numbers of partitions was not supported if we explicitly repartition them (apache#98)
1 parent 64c3f6b commit 11d1f34
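
For context, the failure mode is easy to reproduce: union two datasets that have been explicitly repartitioned to different partition counts while adaptive execution is on. A minimal sketch, assuming a SparkSession named `spark` with adaptive execution enabled (e.g. `spark.sql.adaptive.enabled=true`); it mirrors the repro added to the test suite below:

    // Repro sketch (assumptions: `spark` is a SparkSession, adaptive
    // execution is enabled, e.g. spark.sql.adaptive.enabled=true).
    val dataset1 = spark.range(1000)
    val dataset2 = spark.range(1001)

    // Each side is explicitly repartitioned to a different partition count
    // (505 vs. 105) before the union; prior to this fix the stage was still
    // treated as adaptive even though the pre-shuffle counts do not match.
    val unioned = dataset1.repartition(505, dataset1.col("id"))
      .union(dataset2.repartition(105, dataset2.col("id")))
    unioned.collect()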

File tree

2 files changed: +53, -8 lines


sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStage.scala

Lines changed: 28 additions & 8 deletions
@@ -85,6 +85,33 @@ abstract class QueryStage extends UnaryExecNode {
       Future.sequence(shuffleStageFutures)(implicitly, QueryStage.executionContext), Duration.Inf)
   }

+  def getSupportAdaptiveFlag(queryStageInputs: Seq[ShuffleQueryStageInput]): Boolean = {
+    val queryStageInputsNumPartitions = queryStageInputs.map {
+      _.outputPartitioning match {
+        case hash: HashPartitioning => hash.numPartitions
+        case collection: PartitioningCollection =>
+          val partitioningCollectionNumPartitions = collection.partitionings.map {
+            partitioning =>
+              if (partitioning.isInstanceOf[HashPartitioning]) {
+                partitioning.numPartitions
+              } else {
+                -1
+              }
+          }.distinct
+          if (partitioningCollectionNumPartitions.length > 1) {
+            -1
+          } else {
+            partitioningCollectionNumPartitions.head
+          }
+        case _ => -1
+      }
+    }.distinct
+    val supportAdaptiveFlag = (queryStageInputsNumPartitions.length == 1
+      && queryStageInputsNumPartitions.head != -1)
+    supportAdaptiveFlag
+  }
+
   private var prepared = false

   /**
@@ -127,14 +154,7 @@ abstract class QueryStage extends UnaryExecNode {
     val childMapOutputStatistics = queryStageInputs.map(_.childStage.mapOutputStatistics)
       .filter(_ != null).toArray
     // Right now, adaptive execution only supports HashPartitionings.
-    val supportAdaptive = queryStageInputs.forall {
-      _.outputPartitioning match {
-        case hash: HashPartitioning => true
-        case collection: PartitioningCollection =>
-          collection.partitionings.forall(_.isInstanceOf[HashPartitioning])
-        case _ => false
-      }
-    }
+    val supportAdaptive = getSupportAdaptiveFlag(queryStageInputs)

     if (childMapOutputStatistics.length > 0 && supportAdaptive) {
       val exchangeCoordinator = new ExchangeCoordinator(
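
In plain terms, the new check replaces the old "every partitioning is a HashPartitioning" test with "every input resolves to the same hash partition count", using -1 as a sentinel for anything unsupported. A more compact, equivalent sketch of that logic (illustrative only, not the committed code; `ShuffleQueryStageInput` is this fork's own class, and unlike the patch this version also recurses into nested PartitioningCollections, which the commit treats as unsupported):

    import org.apache.spark.sql.catalyst.plans.physical.{
      HashPartitioning, Partitioning, PartitioningCollection}

    // Illustrative rewrite of getSupportAdaptiveFlag using Option
    // instead of the -1 sentinel.
    def supportsAdaptive(inputs: Seq[ShuffleQueryStageInput]): Boolean = {
      def hashCountOf(p: Partitioning): Option[Int] = p match {
        case hash: HashPartitioning => Some(hash.numPartitions)
        case collection: PartitioningCollection =>
          val counts = collection.partitionings.map(hashCountOf).distinct
          if (counts.length == 1) counts.head else None
        case _ => None
      }
      // Adaptive execution stays on only when every input agrees on a
      // single, defined hash partition count.
      val counts = inputs.map(input => hashCountOf(input.outputPartitioning)).distinct
      counts.length == 1 && counts.head.isDefined
    }

Either way, a union whose branches were repartitioned to 505 and 105 partitions now yields supportAdaptive = false, so the ExchangeCoordinator is skipped rather than fed mismatched pre-shuffle partition counts.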

sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/QueryStageSuite.scala

Lines changed: 25 additions & 0 deletions
@@ -1001,4 +1001,29 @@ class QueryStageSuite extends SparkFunSuite with BeforeAndAfterAll {
         " union all select count(test.age) from test"),
       Row(1) :: Row(1) :: Row(2) :: Nil)
   }
+
+  test("different pre-shuffle partition number of datasets to union with adaptive") {
+    val sparkSession = defaultSparkSession
+    val dataset1 = sparkSession.range(1000)
+    val dataset2 = sparkSession.range(1001)
+
+    val compute = dataset1.repartition(505, dataset1.col("id"))
+      .union(dataset2.repartition(105, dataset2.col("id")))
+
+    assert(compute.orderBy("id").toDF("id").takeAsList(10).toArray
+      === Seq(0, 0, 1, 1, 2, 2, 3, 3, 4, 4).map(i => Row(i)).toArray)
+    compute.explain()
+  }
+
+  test("different pre-shuffle partition number of datasets to join with adaptive") {
+    val sparkSession = defaultSparkSession
+    val dataset1 = sparkSession.range(1000)
+    val dataset2 = sparkSession.range(1001)
+    val compute = dataset1.repartition(105).toDF("key1")
+      .join(dataset1.repartition(505).toDF("key2"), col("key1") === col("key2"), "left")
+    assert(compute.orderBy("key1").toDF("key1", "key2").select("key1").takeAsList(10).toArray
+      === Seq(0, 1, 2, 3, 4, 5, 6, 7, 8, 9).map(i => Row(i)).toArray)
+    compute.explain()
+  }
+
 }
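
Both tests drive the same code path: two inputs whose pre-shuffle partition counts deliberately differ (505 vs. 105), so getSupportAdaptiveFlag returns false and both plans fall back to the non-adaptive path. The union test expects each of the ids 0 through 4 twice because both ranges contain those ids. A quick way to confirm the partition counts outside the suite (a sketch; `spark` and the import are assumptions, not part of the patch):

    import org.apache.spark.sql.functions.col

    val left = spark.range(1000).repartition(505, col("id"))
    val right = spark.range(1001).repartition(105, col("id"))

    // The explicit repartition fixes each side's shuffle partition count.
    println(left.rdd.getNumPartitions)   // 505
    println(right.rdd.getNumPartitions)  // 105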
