Skip to content

Commit 1cac300

Browse files
liancheng authored and marmbrus committed
[SPARK-4322][SQL] Enables struct fields as sub expressions of grouping fields
While resolving struct fields, the resulting `GetField` expression is wrapped with an `Alias` to make it a named expression. Assume `a` is a struct instance with a field `b`; then `"a.b"` will be resolved as `Alias(GetField(a, "b"), "b")`. Thus, for the following SQL query: ```sql SELECT a.b + 1 FROM t GROUP BY a.b + 1 ``` the grouping expression is ```scala Add(GetField(a, "b"), Literal(1, IntegerType)) ``` while the aggregation expression is ```scala Add(Alias(GetField(a, "b"), "b"), Literal(1, IntegerType)) ``` This mismatch makes the above SQL query fail during both the analysis and execution phases. This PR fixes this issue by removing the alias when substituting aggregation expressions. <!-- Reviewable:start --> [<img src="https://reviewable.io/review_button.png" height=40 alt="Review on Reviewable"/>](https://reviewable.io/reviews/apache/spark/3248) <!-- Reviewable:end --> Author: Cheng Lian <[email protected]> Closes #3248 from liancheng/spark-4322 and squashes the following commits: 23a46ea [Cheng Lian] Code simplification dd20a79 [Cheng Lian] Should only trim aliases around `GetField`s 7f46532 [Cheng Lian] Enables struct fields as sub expressions of grouping fields (cherry picked from commit 0c7b66b) Signed-off-by: Michael Armbrust <[email protected]>
1 parent 680bc06 commit 1cac300

File tree

3 files changed

+34
-20
lines changed

3 files changed

+34
-20
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
6060
ResolveFunctions ::
6161
GlobalAggregates ::
6262
UnresolvedHavingClauseAttributes ::
63-
TrimAliases ::
63+
TrimGroupingAliases ::
6464
typeCoercionRules ++
6565
extendedRules : _*),
6666
Batch("Check Analysis", Once,
@@ -93,17 +93,10 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
9393
/**
9494
* Removes no-op Alias expressions from the plan.
9595
*/
96-
object TrimAliases extends Rule[LogicalPlan] {
96+
object TrimGroupingAliases extends Rule[LogicalPlan] {
9797
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
9898
case Aggregate(groups, aggs, child) =>
99-
Aggregate(
100-
groups.map {
101-
_ transform {
102-
case Alias(c, _) => c
103-
}
104-
},
105-
aggs,
106-
child)
99+
Aggregate(groups.map(_.transform { case Alias(c, _) => c }), aggs, child)
107100
}
108101
}
109102

@@ -122,10 +115,15 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
122115
case e => e.children.forall(isValidAggregateExpression)
123116
}
124117

125-
aggregateExprs.foreach { e =>
126-
if (!isValidAggregateExpression(e)) {
127-
throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e")
128-
}
118+
aggregateExprs.find { e =>
119+
!isValidAggregateExpression(e.transform {
120+
// Should trim aliases around `GetField`s. These aliases are introduced while
121+
// resolving struct field accesses, because `GetField` is not a `NamedExpression`.
122+
// (Should we just turn `GetField` into a `NamedExpression`?)
123+
case Alias(g: GetField, _) => g
124+
})
125+
}.foreach { e =>
126+
throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e")
129127
}
130128

131129
aggregatePlan
@@ -328,4 +326,3 @@ object EliminateAnalysisOperators extends Rule[LogicalPlan] {
328326
case Subquery(_, child) => child
329327
}
330328
}
331-

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,15 @@ object PartialAggregation {
151151
val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp {
152152
case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) =>
153153
partialEvaluations(new TreeNodeRef(e)).finalEvaluation
154-
case e: Expression if namedGroupingExpressions.contains(e) =>
155-
namedGroupingExpressions(e).toAttribute
154+
155+
case e: Expression =>
156+
// Should trim aliases around `GetField`s. These aliases are introduced while
157+
// resolving struct field accesses, because `GetField` is not a `NamedExpression`.
158+
// (Should we just turn `GetField` into a `NamedExpression`?)
159+
namedGroupingExpressions
160+
.get(e.transform { case Alias(g: GetField, _) => g })
161+
.map(_.toAttribute)
162+
.getOrElse(e)
156163
}).asInstanceOf[Seq[NamedExpression]]
157164

158165
val partialComputation =
@@ -188,7 +195,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
188195
logDebug(s"Considering join on: $condition")
189196
// Find equi-join predicates that can be evaluated before the join, and thus can be used
190197
// as join keys.
191-
val (joinPredicates, otherPredicates) =
198+
val (joinPredicates, otherPredicates) =
192199
condition.map(splitConjunctivePredicates).getOrElse(Nil).partition {
193200
case EqualTo(l, r) if (canEvaluate(l, left) && canEvaluate(r, right)) ||
194201
(canEvaluate(l, right) && canEvaluate(r, left)) => true
@@ -203,7 +210,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
203210
val rightKeys = joinKeys.map(_._2)
204211

205212
if (joinKeys.nonEmpty) {
206-
logDebug(s"leftKeys:${leftKeys} | rightKeys:${rightKeys}")
213+
logDebug(s"leftKeys:$leftKeys | rightKeys:$rightKeys")
207214
Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right))
208215
} else {
209216
None

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
551551
sql("SELECT * FROM upperCaseData EXCEPT SELECT * FROM upperCaseData"), Nil)
552552
}
553553

554-
test("INTERSECT") {
554+
test("INTERSECT") {
555555
checkAnswer(
556556
sql("SELECT * FROM lowerCaseData INTERSECT SELECT * FROM lowerCaseData"),
557557
(1, "a") ::
@@ -949,4 +949,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
949949
checkAnswer(sql("SELECT key FROM testData WHERE value not like '100%' order by key"),
950950
(1 to 99).map(i => Seq(i)))
951951
}
952+
953+
test("SPARK-4322 Grouping field with struct field as sub expression") {
954+
jsonRDD(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)).registerTempTable("data")
955+
checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), 1)
956+
dropTempTable("data")
957+
958+
jsonRDD(sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data")
959+
checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), 2)
960+
dropTempTable("data")
961+
}
952962
}

0 commit comments

Comments
 (0)