Skip to content

Commit 1e46f92

Browse files
maropucloud-fan
authored andcommitted
[SPARK-24369][SQL] Correct handling for multiple distinct aggregations having the same argument set
## What changes were proposed in this pull request? This pr fixed an issue when having multiple distinct aggregations having the same argument set, e.g., ``` scala>: paste val df = sql( s"""SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) | FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y) """.stripMargin) java.lang.RuntimeException You hit a query analyzer bug. Please report your query to Spark user mailing list. ``` The root cause is that `RewriteDistinctAggregates` can't detect multiple distinct aggregations if they have the same argument set. This pr modified code so that `RewriteDistinctAggregates` could count the number of aggregate expressions with `isDistinct=true`. ## How was this patch tested? Added tests in `DataFrameAggregateSuite`. Author: Takeshi Yamamuro <[email protected]> Closes #21443 from maropu/SPARK-24369.
1 parent 9e7bad0 commit 1e46f92

File tree

4 files changed

+20
-6
lines changed

4 files changed

+20
-6
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
115115
}
116116

117117
// Extract distinct aggregate expressions.
118-
val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e =>
118+
val distincgAggExpressions = aggExpressions.filter(_.isDistinct)
119+
val distinctAggGroups = distincgAggExpressions.groupBy { e =>
119120
val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet
120121
if (unfoldableChildren.nonEmpty) {
121122
// Only expand the unfoldable children
@@ -132,7 +133,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
132133
}
133134

134135
// Aggregation strategy can handle queries with a single distinct group.
135-
if (distinctAggGroups.size > 1) {
136+
if (distincgAggExpressions.size > 1) {
136137
// Create the attributes for the grouping id and the group by clause.
137138
val gid = AttributeReference("gid", IntegerType, nullable = false)()
138139
val groupByMap = a.groupingExpressions.collect {
@@ -151,7 +152,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
151152
}
152153

153154
// Setup unique distinct aggregate children.
154-
val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
155+
val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq
155156
val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
156157
val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)
157158

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
386386
aggregateExpressions.partition(_.isDistinct)
387387
if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
388388
// This is a sanity check. We should not reach here when we have multiple distinct
389-
// column sets. Our MultipleDistinctRewriter should take care this case.
389+
// column sets. Our `RewriteDistinctAggregates` should take care this case.
390390
sys.error("You hit a query analyzer bug. Please report your query to " +
391391
"Spark user mailing list.")
392392
}

sql/core/src/test/resources/sql-tests/inputs/group-by.sql

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,8 @@ SELECT 1 from (
6868
FROM (select 1 as x) a
6969
WHERE false
7070
) b
71-
where b.z != b.z
71+
where b.z != b.z;
72+
73+
-- SPARK-24369 multiple distinct aggregations having the same argument set
74+
SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*)
75+
FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y);

sql/core/src/test/resources/sql-tests/results/group-by.sql.out

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 26
2+
-- Number of queries: 27
33

44

55
-- !query 0
@@ -241,3 +241,12 @@ where b.z != b.z
241241
struct<1:int>
242242
-- !query 25 output
243243

244+
245+
246+
-- !query 26
247+
SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*)
248+
FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y)
249+
-- !query 26 schema
250+
struct<corr(DISTINCT CAST(x AS DOUBLE), CAST(y AS DOUBLE)):double,corr(DISTINCT CAST(y AS DOUBLE), CAST(x AS DOUBLE)):double,count(1):bigint>
251+
-- !query 26 output
252+
1.0 1.0 3

0 commit comments

Comments
 (0)