File tree (3 files changed: +27 -4 lines changed)
main/scala/org/apache/spark/sql/catalyst/optimizer
test/scala/org/apache/spark/sql/catalyst/optimizer
3 files changed: +27 -4 lines changed
Original file line number | Diff line number | Diff line change
@@ -358,7 +358,7 @@ object EliminateDistinct extends Rule[LogicalPlan] {
358358 ae.copy(isDistinct = false )
359359 }
360360
361- private def isDuplicateAgnostic (af : AggregateFunction ): Boolean = af match {
361+ def isDuplicateAgnostic (af : AggregateFunction ): Boolean = af match {
362362 case _ : Max => true
363363 case _ : Min => true
364364 case _ : BitAndAgg => true
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,9 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
4848 }
4949
5050 private def isLowerRedundant (upper : Aggregate , lower : Aggregate ): Boolean = {
51- val upperHasNoAggregateExpressions = ! upper.aggregateExpressions.exists(isAggregate)
51+ val upperHasNoDuplicateSensitiveAgg = ! upper
52+ .aggregateExpressions
53+ .exists(isDuplicateSensitiveAggregate)
5254
5355 lazy val upperRefsOnlyDeterministicNonAgg = upper.references.subsetOf(AttributeSet (
5456 lower
@@ -58,11 +60,18 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
5860 .map(_.toAttribute)
5961 ))
6062
61- upperHasNoAggregateExpressions && upperRefsOnlyDeterministicNonAgg
63+ upperHasNoDuplicateSensitiveAgg && upperRefsOnlyDeterministicNonAgg
6264 }
6365
6466 private def isAggregate (expr : Expression ): Boolean = {
6567 expr.find(e => e.isInstanceOf [AggregateExpression ] ||
6668 PythonUDF .isGroupedAggPandasUDF(e)).isDefined
6769 }
70+
71+ private def isDuplicateSensitiveAggregate (expr : Expression ): Boolean = {
72+ expr.find {
73+ case ae : AggregateExpression => ! EliminateDistinct .isDuplicateAgnostic(ae.aggregateFunction)
74+ case e => PythonUDF .isGroupedAggPandasUDF(e)
75+ }.isDefined
76+ }
6877}
Original file line number | Diff line number | Diff line change
@@ -126,7 +126,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
126126 comparePlans(optimized, expected)
127127 }
128128
129- test(" Keep non-redundant aggregate - upper has agg expression" ) {
129+ test(" Keep non-redundant aggregate - upper has duplicate sensitive agg expression" ) {
130130 val relation = LocalRelation (' a .int, ' b .int)
131131 for (agg <- aggregates(' b )) {
132132 val query = relation
@@ -139,6 +139,20 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
139139 }
140140 }
141141
142+ test(" Remove redundant aggregate - upper has duplicate agnostic agg expression" ) {
143+ val relation = LocalRelation (' a .int, ' b .int)
144+ val query = relation
145+ .groupBy(' a , ' b )(' a , ' b )
146+ // The max does not change if there are duplicate values
147+ .groupBy(' a )(' a , max(' b ))
148+ .analyze
149+ val expected = relation
150+ .groupBy(' a )(' a , max(' b ))
151+ .analyze
152+ val optimized = Optimize .execute(query)
153+ comparePlans(optimized, expected)
154+ }
155+
142156 test(" Keep non-redundant aggregate - upper references agg expression" ) {
143157 val relation = LocalRelation (' a .int, ' b .int)
144158 for (agg <- aggregates(' b )) {
You can’t perform that action at this time.
0 commit comments