From 479df75e0bc509e525266d68934ead1fcf4f685c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 1 Sep 2015 15:26:10 +0800 Subject: [PATCH 1/2] support order by non-attribute grouping expression on Aggregate --- .../sql/catalyst/analysis/Analyzer.scala | 52 +++++++------------ .../org/apache/spark/sql/SQLQuerySuite.scala | 19 +++++-- 2 files changed, 33 insertions(+), 38 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1a5de15c61f86..ee02d501d10f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -560,43 +560,29 @@ class Analyzer( filter } - case sort @ Sort(sortOrder, global, - aggregate @ Aggregate(grouping, originalAggExprs, child)) + case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved && !sort.resolved => // Try resolving the ordering as though it is in the aggregate clause. try { - val aliasedOrder = sortOrder.map(o => Alias(o.child, "aggOrder")()) - val aggregatedOrdering = Aggregate(grouping, aliasedOrder, child) + val aliasedOrdering = sortOrder.map(o => Alias(o.child, "aggOrder")()) + val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) val resolvedOperator: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] - def resolvedAggregateOrdering = resolvedOperator.aggregateExpressions - - // Expressions that have an aggregate can be pushed down. - val needsAggregate = resolvedAggregateOrdering.exists(containsAggregate) - - // Attribute references, that are missing from the order but are present in the grouping - // expressions can also be pushed down. 
- val requiredAttributes = resolvedAggregateOrdering.map(_.references).reduce(_ ++ _) - val missingAttributes = requiredAttributes -- aggregate.outputSet - val validPushdownAttributes = - missingAttributes.filter(a => grouping.exists(a.semanticEquals)) - - // If resolution was successful and we see the ordering either has an aggregate in it or - // it is missing something that is projected away by the aggregate, add the ordering - // the original aggregate operator. - if (resolvedOperator.resolved && (needsAggregate || validPushdownAttributes.nonEmpty)) { - val evaluatedOrderings: Seq[SortOrder] = sortOrder.zip(resolvedAggregateOrdering).map { - case (order, evaluated) => order.copy(child = evaluated.toAttribute) - } - val aggExprsWithOrdering: Seq[NamedExpression] = - resolvedAggregateOrdering ++ originalAggExprs - - Project(aggregate.output, - Sort(evaluatedOrderings, global, - aggregate.copy(aggregateExpressions = aggExprsWithOrdering))) - } else { - sort + val resolvedOrdering = resolvedOperator.aggregateExpressions + + // If we pass the analysis check, then the ordering expressions should only reference to + // aggregate expressions or grouping expressions, and it's safe to push them down to + // Aggregate. + checkAnalysis(resolvedOperator) + // todo: some ordering expressions can be evaluated with existing aggregate expressions + // and we don't need to push them down to Aggregate. + val evaluatedOrderings: Seq[SortOrder] = sortOrder.zip(resolvedOrdering).map { + case (order, evaluated) => order.copy(child = evaluated.toAttribute) } + val aggExprsWithOrdering = aggregate.aggregateExpressions ++ resolvedOrdering + Project(aggregate.output, + Sort(evaluatedOrderings, global, + aggregate.copy(aggregateExpressions = aggExprsWithOrdering))) } catch { // Attempting to resolve in the aggregate can result in ambiguity. When this happens, // just return the original plan. 
@@ -605,9 +591,7 @@ class Analyzer( } protected def containsAggregate(condition: Expression): Boolean = { - condition - .collect { case ae: AggregateExpression => ae } - .nonEmpty + condition.find(_.isInstanceOf[AggregateExpression]).isDefined } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 9e172b2c264cb..c7f6b99ff0b53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1712,9 +1712,20 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-10130 type coercion for IF should have children resolved first") { - val df = Seq((1, 1), (-1, 1)).toDF("key", "value") - df.registerTempTable("src") - checkAnswer( - sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0))) + withTempTable("src") { + Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src") + checkAnswer( + sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0))) + } + } + + test("SPARK-10389: order by non-attribute grouping expression on Aggregate") { + withTempTable("src") { + Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src") + checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY key + 1"), + Seq(Row(1), Row(1))) + checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY (key + 1) * 2"), + Seq(Row(1), Row(1))) + } } } From 806c713c63df82e2e9941a87d3f2a94852c5e7e6 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 2 Sep 2015 17:28:48 +0800 Subject: [PATCH 2/2] finish the todo --- .../sql/catalyst/analysis/Analyzer.scala | 36 ++++++++++++++----- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala 
index ee02d501d10f0..591747b45c376 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -567,22 +567,40 @@ class Analyzer( try { val aliasedOrdering = sortOrder.map(o => Alias(o.child, "aggOrder")()) val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) - val resolvedOperator: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] - val resolvedOrdering = resolvedOperator.aggregateExpressions + val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] + val resolvedAliasedOrdering: Seq[Alias] = + resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]] // If we pass the analysis check, then the ordering expressions should only reference to // aggregate expressions or grouping expressions, and it's safe to push them down to // Aggregate. - checkAnalysis(resolvedOperator) - // todo: some ordering expressions can be evaluated with existing aggregate expressions - // and we don't need to push them down to Aggregate. - val evaluatedOrderings: Seq[SortOrder] = sortOrder.zip(resolvedOrdering).map { - case (order, evaluated) => order.copy(child = evaluated.toAttribute) + checkAnalysis(resolvedAggregate) + + val originalAggExprs = aggregate.aggregateExpressions.map( + CleanupAliases.trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) + + // If the ordering expression is the same as an original aggregate expression, we don't + // need to push down this ordering expression and can reference the original aggregate + // expression instead. 
+ val needsPushDown = ArrayBuffer.empty[NamedExpression] + val evaluatedOrderings = resolvedAliasedOrdering.zip(sortOrder).map { + case (evaluated, order) => + val index = originalAggExprs.indexWhere { + case Alias(child, _) => child semanticEquals evaluated.child + case other => other semanticEquals evaluated.child + } + + if (index == -1) { + needsPushDown += evaluated + order.copy(child = evaluated.toAttribute) + } else { + order.copy(child = originalAggExprs(index).toAttribute) + } } - val aggExprsWithOrdering = aggregate.aggregateExpressions ++ resolvedOrdering + Project(aggregate.output, Sort(evaluatedOrderings, global, - aggregate.copy(aggregateExpressions = aggExprsWithOrdering))) + aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown))) } catch { // Attempting to resolve in the aggregate can result in ambiguity. When this happens, // just return the original plan.