Skip to content

Commit 35b0520

Browse files
committed
Use semanticEquals to replace grouping expressions in the output of the aggregate operator.
1 parent 3b43b24 commit 35b0520

File tree

2 files changed

+50
-5
lines changed

2 files changed

+50
-5
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
2525
import org.apache.spark.sql.execution.SparkPlan
2626
import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
2727

28+
/**
29+
* Utility functions used by the query planner to convert our plan to new aggregation code path.
30+
*/
2831
object Utils {
2932
// Right now, we do not support complex types in the grouping key schema.
3033
private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = {
@@ -214,11 +217,15 @@ object Utils {
214217
expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
215218
}
216219
val rewrittenResultExpressions = resultExpressions.map { expr =>
217-
expr.transform {
220+
expr.transformDown {
218221
case agg: AggregateExpression2 =>
219222
aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
220-
case expression if groupExpressionMap.contains(expression) =>
221-
groupExpressionMap(expression).toAttribute
223+
case expression =>
224+
// We do not rely on the equality check at here since attributes may
225+
// different cosmetically. Instead, we use semanticEquals.
226+
groupExpressionMap.collectFirst {
227+
case (expr, ne) if expr semanticEquals expression => ne.toAttribute
228+
}.getOrElse(expression)
222229
}.asInstanceOf[NamedExpression]
223230
}
224231
val finalAggregate = Aggregate2Sort(
@@ -334,8 +341,12 @@ object Utils {
334341
expr.transform {
335342
case agg: AggregateExpression2 =>
336343
aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
337-
case expression if groupExpressionMap.contains(expression) =>
338-
groupExpressionMap(expression).toAttribute
344+
case expression =>
345+
// We do not rely on the equality check at here since attributes may
346+
// different cosmetically. Instead, we use semanticEquals.
347+
groupExpressionMap.collectFirst {
348+
case (expr, ne) if expr semanticEquals expression => ne.toAttribute
349+
}.getOrElse(expression)
339350
}.asInstanceOf[NamedExpression]
340351
}
341352
val finalAndCompleteAggregate = FinalAndCompleteAggregate2Sort(

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,40 @@ class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAf
187187
Row(null, null) :: Nil)
188188
}
189189

190+
test("case in-sensitive resolution") {
191+
checkAnswer(
192+
sqlContext.sql(
193+
"""
194+
|SELECT avg(value), kEY - 100
195+
|FROM agg1
196+
|GROUP BY Key - 100
197+
""".stripMargin),
198+
Row(20.0, -99) :: Row(-0.5, -98) :: Row(null, -97) :: Row(10.0, null) :: Nil)
199+
200+
checkAnswer(
201+
sqlContext.sql(
202+
"""
203+
|SELECT sum(distinct value1), kEY - 100, count(distinct value1)
204+
|FROM agg2
205+
|GROUP BY Key - 100
206+
""".stripMargin),
207+
Row(40, -99, 2) :: Row(0, -98, 2) :: Row(null, -97, 0) :: Row(30, null, 3) :: Nil)
208+
209+
checkAnswer(
210+
sqlContext.sql(
211+
"""
212+
|SELECT valUe * key - 100
213+
|FROM agg1
214+
|GROUP BY vAlue * keY - 100
215+
""".stripMargin),
216+
Row(-90) ::
217+
Row(-80) ::
218+
Row(-70) ::
219+
Row(-100) ::
220+
Row(-102) ::
221+
Row(null) :: Nil)
222+
}
223+
190224
test("test average no key in output") {
191225
checkAnswer(
192226
sqlContext.sql(

0 commit comments

Comments
 (0)