diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c5c2a5b23667..7538a6477f49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1003,18 +1003,32 @@ class Analyzer( */ object ResolveAggAliasInGroupBy extends Rule[LogicalPlan] { + // This is a strict check though, we put this to apply the rule only if the expression is not + // resolvable by child. + private def notResolvableByChild(attrName: String, child: LogicalPlan): Boolean = { + !child.output.exists(a => resolver(a.name, attrName)) + } + + private def mayResolveAttrByAggregateExprs( + exprs: Seq[Expression], aggs: Seq[NamedExpression], child: LogicalPlan): Seq[Expression] = { + exprs.map { _.transform { + case u: UnresolvedAttribute if notResolvableByChild(u.name, child) => + aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u) + }} + } + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case agg @ Aggregate(groups, aggs, child) + if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && + groups.exists(!_.resolved) => + agg.copy(groupingExpressions = mayResolveAttrByAggregateExprs(groups, aggs, child)) + + case gs @ GroupingSets(selectedGroups, groups, child, aggs) if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && groups.exists(_.isInstanceOf[UnresolvedAttribute]) => - // This is a strict check though, we put this to apply the rule only in alias expressions - def notResolvableByChild(attrName: String): Boolean = - !child.output.exists(a => resolver(a.name, attrName)) - agg.copy(groupingExpressions = groups.map { - case u: UnresolvedAttribute if notResolvableByChild(u.name) => - aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u) - case e => e - }) + gs.copy( + selectedGroupByExprs = selectedGroups.map(mayResolveAttrByAggregateExprs(_, aggs, child)), + groupByExprs = mayResolveAttrByAggregateExprs(groups, aggs, child)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index f663d7b8a8f7..2c19265bedc5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -704,7 +704,7 @@ case class Expand( * We will transform GROUPING SETS into logical plan Aggregate(.., Expand) in Analyzer * * @param selectedGroupByExprs A sequence of selected GroupBy expressions, all exprs should - * exists in groupByExprs. + * exist in groupByExprs. * @param groupByExprs The Group By expressions candidates. * @param child Child operator * @param aggregations The Aggregation expressions, those non selected group by expressions diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql index f8135389a9e5..8aff4cb52419 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql @@ -54,4 +54,9 @@ SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY CUBE(co ORDER BY GROUPING(course), GROUPING(year), course, year; SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING(course); SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING_ID(course); -SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id; \ No newline at end of file +SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id; + +-- Aliases in SELECT could be used in ROLLUP/CUBE/GROUPING SETS +SELECT a + b AS k1, b AS k2, SUM(a - b) FROM testData GROUP BY CUBE(k1, k2); +SELECT a + b AS k, b, SUM(a - b) FROM testData GROUP BY ROLLUP(k, b); +SELECT a + b, b AS k, SUM(a - b) FROM testData GROUP BY a + b, k GROUPING SETS(k) diff --git a/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out b/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out index 825e8f5488c8..ce7a16a4d0c8 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 26 +-- Number of queries: 29 -- !query 0 @@ -328,3 +328,50 @@ struct<> -- !query 25 output org.apache.spark.sql.AnalysisException grouping__id is deprecated; use grouping_id() instead; + + +-- !query 26 +SELECT a + b AS k1, b AS k2, SUM(a - b) FROM testData GROUP BY CUBE(k1, k2) +-- !query 26 schema +struct +-- !query 26 output +2 1 0 +2 NULL 0 +3 1 1 +3 2 -1 +3 NULL 0 +4 1 2 +4 2 0 +4 NULL 2 +5 2 1 +5 NULL 1 +NULL 1 3 +NULL 2 0 +NULL NULL 3 + + +-- !query 27 +SELECT a + b AS k, b, SUM(a - b) FROM testData GROUP BY ROLLUP(k, b) +-- !query 27 schema +struct +-- !query 27 output +2 1 0 +2 NULL 0 +3 1 1 +3 2 -1 +3 NULL 0 +4 1 2 +4 2 0 +4 NULL 2 +5 2 1 +5 NULL 1 +NULL NULL 3 + + +-- !query 28 +SELECT a + b, b AS k, SUM(a - b) FROM testData GROUP BY a + b, k GROUPING SETS(k) +-- !query 28 schema +struct<(a + b):int,k:int,sum((a - b)):bigint> +-- !query 28 output +NULL 1 3 +NULL 2 0