Skip to content

Commit 8038a36

Browse files
committed
handle aggs too
1 parent 0b9c687 commit 8038a36

File tree

2 files changed

+34
-11
lines changed

2 files changed

+34
-11
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -247,15 +247,13 @@ class Analyzer(catalog: Catalog,
247247
(oldVersion, newVersion)
248248

249249
// Handle projects that create conflicting aliases.
250-
case oldVersion @ Project(projectList, child)
251-
if newAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
252-
val newVersion =
253-
oldVersion.copy(
254-
projectList = projectList.map {
255-
case a: Alias => Alias(a.child, a.name)()
256-
case other => other
257-
})
258-
(oldVersion, newVersion)
250+
case oldVersion @ Project(projectList, _)
251+
if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
252+
(oldVersion, oldVersion.copy(projectList = newAliases(projectList)))
253+
254+
case oldVersion @ Aggregate(_, aggregateExpressions, _)
255+
if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
256+
(oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))
259257
}.head // Only handle first case found, others will be fixed on the next pass.
260258

261259
val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output))
@@ -285,7 +283,14 @@ class Analyzer(catalog: Catalog,
285283
}
286284
}
287285

288-
def newAliases(projectList: Seq[NamedExpression]): AttributeSet = {
286+
def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
287+
expressions.map {
288+
case a: Alias => Alias(a.child, a.name)()
289+
case other => other
290+
}
291+
}
292+
293+
def findAliases(projectList: Seq[NamedExpression]): AttributeSet = {
289294
AttributeSet(projectList.collect { case a: Alias => a.toAttribute })
290295
}
291296

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
3737
val sqlCtx = TestSQLContext
3838

3939
test("self join with aliases") {
40-
val df = Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str").registerTempTable("df")
40+
Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str").registerTempTable("df")
4141

4242
checkAnswer(
4343
sql(
@@ -49,6 +49,24 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
4949
Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil)
5050
}
5151

52+
test("self join with alias in agg") {
53+
Seq(1,2,3)
54+
.map(i => (i, i.toString))
55+
.toDF("int", "str")
56+
.groupBy("str")
57+
.agg($"str", count("str").as("strCount"))
58+
.registerTempTable("df")
59+
60+
checkAnswer(
61+
sql(
62+
"""
63+
|SELECT x.str, SUM(x.strCount)
64+
|FROM df x JOIN df y ON x.str = y.str
65+
|GROUP BY x.str
66+
""".stripMargin),
67+
Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil)
68+
}
69+
5270
test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") {
5371
checkAnswer(
5472
sql("SELECT a FROM testData2 SORT BY a"),

0 commit comments

Comments
 (0)