@@ -576,43 +576,47 @@ class Analyzer(
576576 filter
577577 }
578578
579- case sort @ Sort (sortOrder, global,
580- aggregate @ Aggregate (grouping, originalAggExprs, child))
579+ case sort @ Sort (sortOrder, global, aggregate : Aggregate )
581580 if aggregate.resolved && ! sort.resolved =>
582581
583582 // Try resolving the ordering as though it is in the aggregate clause.
584583 try {
585- val aliasedOrder = sortOrder.map(o => Alias (o.child, " aggOrder" )())
586- val aggregatedOrdering = Aggregate (grouping, aliasedOrder, child)
587- val resolvedOperator : Aggregate = execute(aggregatedOrdering).asInstanceOf [Aggregate ]
588- def resolvedAggregateOrdering = resolvedOperator.aggregateExpressions
589-
590- // Expressions that have an aggregate can be pushed down.
591- val needsAggregate = resolvedAggregateOrdering.exists(containsAggregate)
592-
593- // Attribute references, that are missing from the order but are present in the grouping
594- // expressions can also be pushed down.
595- val requiredAttributes = resolvedAggregateOrdering.map(_.references).reduce(_ ++ _)
596- val missingAttributes = requiredAttributes -- aggregate.outputSet
597- val validPushdownAttributes =
598- missingAttributes.filter(a => grouping.exists(a.semanticEquals))
599-
600- // If resolution was successful and we see the ordering either has an aggregate in it or
601- // it is missing something that is projected away by the aggregate, add the ordering
602- // the original aggregate operator.
603- if (resolvedOperator.resolved && (needsAggregate || validPushdownAttributes.nonEmpty)) {
604- val evaluatedOrderings : Seq [SortOrder ] = sortOrder.zip(resolvedAggregateOrdering).map {
605- case (order, evaluated) => order.copy(child = evaluated.toAttribute)
606- }
607- val aggExprsWithOrdering : Seq [NamedExpression ] =
608- resolvedAggregateOrdering ++ originalAggExprs
609-
610- Project (aggregate.output,
611- Sort (evaluatedOrderings, global,
612- aggregate.copy(aggregateExpressions = aggExprsWithOrdering)))
613- } else {
614- sort
584+ val aliasedOrdering = sortOrder.map(o => Alias (o.child, " aggOrder" )())
585+ val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
586+ val resolvedAggregate : Aggregate = execute(aggregatedOrdering).asInstanceOf [Aggregate ]
587+ val resolvedAliasedOrdering : Seq [Alias ] =
588+ resolvedAggregate.aggregateExpressions.asInstanceOf [Seq [Alias ]]
589+
590+ // If we pass the analysis check, then the ordering expressions should only reference to
591+ // aggregate expressions or grouping expressions, and it's safe to push them down to
592+ // Aggregate.
593+ checkAnalysis(resolvedAggregate)
594+
595+ val originalAggExprs = aggregate.aggregateExpressions.map(
596+ CleanupAliases .trimNonTopLevelAliases(_).asInstanceOf [NamedExpression ])
597+
598+ // If the ordering expression is same with original aggregate expression, we don't need
599+ // to push down this ordering expression and can reference the original aggregate
600+ // expression instead.
601+ val needsPushDown = ArrayBuffer .empty[NamedExpression ]
602+ val evaluatedOrderings = resolvedAliasedOrdering.zip(sortOrder).map {
603+ case (evaluated, order) =>
604+ val index = originalAggExprs.indexWhere {
605+ case Alias (child, _) => child semanticEquals evaluated.child
606+ case other => other semanticEquals evaluated.child
607+ }
608+
609+ if (index == - 1 ) {
610+ needsPushDown += evaluated
611+ order.copy(child = evaluated.toAttribute)
612+ } else {
613+ order.copy(child = originalAggExprs(index).toAttribute)
614+ }
615615 }
616+
617+ Project (aggregate.output,
618+ Sort (evaluatedOrderings, global,
619+ aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown)))
616620 } catch {
617621 // Attempting to resolve in the aggregate can result in ambiguity. When this happens,
618622 // just return the original plan.
@@ -621,9 +625,7 @@ class Analyzer(
621625 }
622626
623627 protected def containsAggregate (condition : Expression ): Boolean = {
624- condition
625- .collect { case ae : AggregateExpression => ae }
626- .nonEmpty
628+ condition.find(_.isInstanceOf [AggregateExpression ]).isDefined
627629 }
628630 }
629631
0 commit comments