@@ -25,6 +25,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
2525import org .apache .spark .sql .execution .SparkPlan
2626import org .apache .spark .sql .types .{StructType , MapType , ArrayType }
2727
28+ /**
29+ * Utility functions used by the query planner to convert our plan to new aggregation code path.
30+ */
2831object Utils {
2932 // Right now, we do not support complex types in the grouping key schema.
3033 private def supportsGroupingKeySchema (aggregate : Aggregate ): Boolean = {
@@ -214,11 +217,15 @@ object Utils {
214217 expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
215218 }
216219 val rewrittenResultExpressions = resultExpressions.map { expr =>
217- expr.transform {
220+ expr.transformDown {
218221 case agg : AggregateExpression2 =>
219222 aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
220- case expression if groupExpressionMap.contains(expression) =>
221- groupExpressionMap(expression).toAttribute
223+ case expression =>
224+ // We do not rely on the equality check at here since attributes may
225+ // different cosmetically. Instead, we use semanticEquals.
226+ groupExpressionMap.collectFirst {
227+ case (expr, ne) if expr semanticEquals expression => ne.toAttribute
228+ }.getOrElse(expression)
222229 }.asInstanceOf [NamedExpression ]
223230 }
224231 val finalAggregate = Aggregate2Sort (
@@ -334,8 +341,12 @@ object Utils {
334341 expr.transform {
335342 case agg : AggregateExpression2 =>
336343 aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
337- case expression if groupExpressionMap.contains(expression) =>
338- groupExpressionMap(expression).toAttribute
344+ case expression =>
345+ // We do not rely on the equality check at here since attributes may
346+ // different cosmetically. Instead, we use semanticEquals.
347+ groupExpressionMap.collectFirst {
348+ case (expr, ne) if expr semanticEquals expression => ne.toAttribute
349+ }.getOrElse(expression)
339350 }.asInstanceOf [NamedExpression ]
340351 }
341352 val finalAndCompleteAggregate = FinalAndCompleteAggregate2Sort (
0 commit comments