Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1003,18 +1003,32 @@ class Analyzer(
*/
object ResolveAggAliasInGroupBy extends Rule[LogicalPlan] {

  /**
   * Returns true when no output attribute of `child` matches `attrName` under `resolver`.
   *
   * This is a deliberately strict check: the rule is applied only to names that the child
   * cannot resolve.
   */
  private def notResolvableByChild(attrName: String, child: LogicalPlan): Boolean = {
    !child.output.exists(a => resolver(a.name, attrName))
  }

  /**
   * Rewrites the given grouping expressions, replacing each [[UnresolvedAttribute]] visited by
   * `transform` whose name the child cannot resolve with the first expression in `aggs` whose
   * name matches under `resolver`. Attributes with no matching aggregate expression are kept
   * as they are.
   *
   * @param exprs grouping expressions to rewrite
   * @param aggs aggregate expressions whose names are candidate aliases
   * @param child child plan whose output takes precedence over aliases
   * @return the rewritten expressions, in the same order as `exprs`
   */
  private def mayResolveAttrByAggregateExprs(
      exprs: Seq[Expression], aggs: Seq[NamedExpression], child: LogicalPlan): Seq[Expression] = {
    exprs.map { _.transform {
      case u: UnresolvedAttribute if notResolvableByChild(u.name, child) =>
        aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u)
    }}
  }

  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
    // Only fires when aliases in GROUP BY are enabled, the child and all aggregate expressions
    // are already resolved, and at least one grouping expression is still unresolved.
    case agg @ Aggregate(groups, aggs, child)
        if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) &&
          groups.exists(!_.resolved) =>
      agg.copy(groupingExpressions = mayResolveAttrByAggregateExprs(groups, aggs, child))

    // NOTE(review): this guard only looks for a top-level UnresolvedAttribute in `groups`,
    // whereas the Aggregate case above checks for any unresolved grouping expression --
    // confirm the difference is intentional.
    case gs @ GroupingSets(selectedGroups, groups, child, aggs)
        if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) &&
          groups.exists(_.isInstanceOf[UnresolvedAttribute]) =>
      // Both the candidate group-by expressions and every selected grouping set are rewritten.
      gs.copy(
        selectedGroupByExprs = selectedGroups.map(mayResolveAttrByAggregateExprs(_, aggs, child)),
        groupByExprs = mayResolveAttrByAggregateExprs(groups, aggs, child))
  }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,7 @@ case class Expand(
* We will transform GROUPING SETS into logical plan Aggregate(.., Expand) in Analyzer
*
* @param selectedGroupByExprs A sequence of selected GroupBy expressions, all exprs should
* exist in groupByExprs.
* @param groupByExprs The Group By expressions candidates.
* @param child Child operator
* @param aggregations The Aggregation expressions, those non selected group by expressions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,9 @@ SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY CUBE(co
ORDER BY GROUPING(course), GROUPING(year), course, year;
-- GROUPING / GROUPING_ID / grouping__id used in ORDER BY
SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING(course);
SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING_ID(course);
SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id;

-- Aliases in SELECT could be used in ROLLUP/CUBE/GROUPING SETS
SELECT a + b AS k1, b AS k2, SUM(a - b) FROM testData GROUP BY CUBE(k1, k2);
SELECT a + b AS k, b, SUM(a - b) FROM testData GROUP BY ROLLUP(k, b);
SELECT a + b, b AS k, SUM(a - b) FROM testData GROUP BY a + b, k GROUPING SETS(k);
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 29


-- !query 0
Expand Down Expand Up @@ -328,3 +328,50 @@ struct<>
-- !query 25 output
org.apache.spark.sql.AnalysisException
grouping__id is deprecated; use grouping_id() instead;


-- !query 26
SELECT a + b AS k1, b AS k2, SUM(a - b) FROM testData GROUP BY CUBE(k1, k2)
-- !query 26 schema
struct<k1:int,k2:int,sum((a - b)):bigint>
-- !query 26 output
2 1 0
2 NULL 0
3 1 1
3 2 -1
3 NULL 0
4 1 2
4 2 0
4 NULL 2
5 2 1
5 NULL 1
NULL 1 3
NULL 2 0
NULL NULL 3


-- !query 27
SELECT a + b AS k, b, SUM(a - b) FROM testData GROUP BY ROLLUP(k, b)
-- !query 27 schema
struct<k:int,b:int,sum((a - b)):bigint>
-- !query 27 output
2 1 0
2 NULL 0
3 1 1
3 2 -1
3 NULL 0
4 1 2
4 2 0
4 NULL 2
5 2 1
5 NULL 1
NULL NULL 3


-- !query 28
SELECT a + b, b AS k, SUM(a - b) FROM testData GROUP BY a + b, k GROUPING SETS(k)
-- !query 28 schema
struct<(a + b):int,k:int,sum((a - b)):bigint>
-- !query 28 output
NULL 1 3
NULL 2 0