Skip to content

Commit 19748dc

Browse files
committed
Allow duplicate angostic aggs
1 parent fba05c7 commit 19748dc

File tree

3 files changed

+27
-4
lines changed

3 files changed

+27
-4
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ object EliminateDistinct extends Rule[LogicalPlan] {
358358
ae.copy(isDistinct = false)
359359
}
360360

361-
private def isDuplicateAgnostic(af: AggregateFunction): Boolean = af match {
361+
def isDuplicateAgnostic(af: AggregateFunction): Boolean = af match {
362362
case _: Max => true
363363
case _: Min => true
364364
case _: BitAndAgg => true

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
4848
}
4949

5050
private def isLowerRedundant(upper: Aggregate, lower: Aggregate): Boolean = {
51-
val upperHasNoAggregateExpressions = !upper.aggregateExpressions.exists(isAggregate)
51+
val upperHasNoDuplicateSensitiveAgg = !upper
52+
.aggregateExpressions
53+
.exists(isDuplicateSensitiveAggregate)
5254

5355
lazy val upperRefsOnlyDeterministicNonAgg = upper.references.subsetOf(AttributeSet(
5456
lower
@@ -58,11 +60,18 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
5860
.map(_.toAttribute)
5961
))
6062

61-
upperHasNoAggregateExpressions && upperRefsOnlyDeterministicNonAgg
63+
upperHasNoDuplicateSensitiveAgg && upperRefsOnlyDeterministicNonAgg
6264
}
6365

6466
private def isAggregate(expr: Expression): Boolean = {
6567
expr.find(e => e.isInstanceOf[AggregateExpression] ||
6668
PythonUDF.isGroupedAggPandasUDF(e)).isDefined
6769
}
70+
71+
private def isDuplicateSensitiveAggregate(expr: Expression): Boolean = {
72+
expr.find {
73+
case ae: AggregateExpression => !EliminateDistinct.isDuplicateAgnostic(ae.aggregateFunction)
74+
case e => PythonUDF.isGroupedAggPandasUDF(e)
75+
}.isDefined
76+
}
6877
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
126126
comparePlans(optimized, expected)
127127
}
128128

129-
test("Keep non-redundant aggregate - upper has agg expression") {
129+
test("Keep non-redundant aggregate - upper has duplicate sensitive agg expression") {
130130
val relation = LocalRelation('a.int, 'b.int)
131131
for (agg <- aggregates('b)) {
132132
val query = relation
@@ -139,6 +139,20 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
139139
}
140140
}
141141

142+
test("Remove redundant aggregate - upper has duplicate agnostic agg expression") {
143+
val relation = LocalRelation('a.int, 'b.int)
144+
val query = relation
145+
.groupBy('a, 'b)('a, 'b)
146+
// The max does not change if there are duplicate values
147+
.groupBy('a)('a, max('b))
148+
.analyze
149+
val expected = relation
150+
.groupBy('a)('a, max('b))
151+
.analyze
152+
val optimized = Optimize.execute(query)
153+
comparePlans(optimized, expected)
154+
}
155+
142156
test("Keep non-redundant aggregate - upper references agg expression") {
143157
val relation = LocalRelation('a.int, 'b.int)
144158
for (agg <- aggregates('b)) {

0 commit comments

Comments
 (0)