From 71249cfeb9e798e41d8ef0f7423b48685a1774e7 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 29 Apr 2016 16:17:05 -0700 Subject: [PATCH] [SPARK-15020][SQL] GROUP-BY should support Aliases ``` sql("select a x from values 1 T(a) group by x").explain org.apache.spark.sql.AnalysisException: cannot resolve '`x`' given input columns: [a]; line 1 pos 39 ``` --- .../spark/sql/catalyst/analysis/Analyzer.scala | 18 ++++++++++++++++++ .../sql/catalyst/analysis/AnalysisSuite.scala | 6 ++++++ 2 files changed, 24 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f6a65f7e6c09..400b06540feb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -514,6 +514,13 @@ class Analyzer( } else { a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child)) } + + // When resolve grouping expressions in Aggregate, we need to consider aliases in + // aggregationExprs; e.g. SELECT 1 a GROUP BY a + case a @ Aggregate(ge, ae, child) if child.resolved && ae.forall(_.resolved) && !a.resolved => + val newGroupingExprs = ge.map(expandGroupingExpr(_, ae)) + a.copy(groupingExpressions = newGroupingExprs) + // If the script transformation input contains Stars, expand it. case t: ScriptTransformation if containsStar(t.input) => t.copy( @@ -633,6 +640,17 @@ class Analyzer( failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'") } } + + /** + * Expands the matching attribute in aggregation expression aliases. + */ + private def expandGroupingExpr(expr: Expression, aggregationExprs: Seq[NamedExpression]) = + if (!expr.resolved && expr.isInstanceOf[UnresolvedAttribute]) { + val name = expr.asInstanceOf[UnresolvedAttribute].name + aggregationExprs.filter(x => x.resolved && x.name == name).headOption.getOrElse(expr) + } else { + expr + } } private def containsDeserializer(exprs: Seq[Expression]): Boolean = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index a63d1770f325..ddf2ba8592e5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -346,4 +346,10 @@ class AnalysisSuite extends AnalysisTest { assertAnalysisSuccess(query) } + + test("SPARK-15020: GROUP-BY should support Aliases") { + val input = LocalRelation('a.int) + val query = input.groupBy('x)('a.as('x)) + assertAnalysisSuccess(query) + } }