From 7f4653246647b32299b0f48414391323fcd4273a Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 14 Nov 2014 00:34:09 +0800 Subject: [PATCH 1/3] Enables struct fields as sub expressions of grouping fields --- .../sql/catalyst/analysis/Analyzer.scala | 24 +++++++------------ .../sql/catalyst/planning/patterns.scala | 8 +++++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 12 +++++++++- 3 files changed, 26 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a448c794213ae..dc54e732066fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -60,7 +60,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool ResolveFunctions :: GlobalAggregates :: UnresolvedHavingClauseAttributes :: - TrimAliases :: + TrimGroupingAliases :: typeCoercionRules ++ extendedRules : _*), Batch("Check Analysis", Once, @@ -70,6 +70,8 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool EliminateAnalysisOperators) ) + private def trimAliases(e: Expression) = e.transform { case Alias(c, _) => c } + /** * Makes sure all attributes and logical plans have been resolved. */ @@ -93,17 +95,10 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool /** * Removes no-op Alias expressions from the plan. */ - object TrimAliases extends Rule[LogicalPlan] { + object TrimGroupingAliases extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case Aggregate(groups, aggs, child) => - Aggregate( - groups.map { - _ transform { - case Alias(c, _) => c - } - }, - aggs, - child) + Aggregate(groups.map(trimAliases), aggs, child) } } @@ -122,10 +117,10 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool case e => e.children.forall(isValidAggregateExpression) } - aggregateExprs.foreach { e => - if (!isValidAggregateExpression(e)) { - throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e") - } + aggregateExprs.find { e => + !isValidAggregateExpression(trimAliases(e)) + }.foreach { e => + throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e") } aggregatePlan @@ -328,4 +323,3 @@ object EliminateAnalysisOperators extends Rule[LogicalPlan] { case Subquery(_, child) => child } } - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index f0fd9a8b9a46e..be38b5c7f3a91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -146,6 +146,8 @@ object PartialAggregation { case other => (other, Alias(other, "PartialGroup")()) }.toMap + def trimAliases(e: Expression) = e.transform { case Alias(c, _) => c } + // Replace aggregations with a new expression that computes the result from the already // computed partial evaluations and grouping values. val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { @@ -153,6 +155,8 @@ object PartialAggregation { partialEvaluations(new TreeNodeRef(e)).finalEvaluation case e: Expression if namedGroupingExpressions.contains(e) => namedGroupingExpressions(e).toAttribute + case e: Expression if namedGroupingExpressions.contains(trimAliases(e)) => + namedGroupingExpressions(trimAliases(e)).toAttribute }).asInstanceOf[Seq[NamedExpression]] val partialComputation = @@ -188,7 +192,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { logDebug(s"Considering join on: $condition") // Find equi-join predicates that can be evaluated before the join, and thus can be used // as join keys. - val (joinPredicates, otherPredicates) = + val (joinPredicates, otherPredicates) = condition.map(splitConjunctivePredicates).getOrElse(Nil).partition { case EqualTo(l, r) if (canEvaluate(l, left) && canEvaluate(r, right)) || (canEvaluate(l, right) && canEvaluate(r, left)) => true @@ -203,7 +207,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { val rightKeys = joinKeys.map(_._2) if (joinKeys.nonEmpty) { - logDebug(s"leftKeys:${leftKeys} | rightKeys:${rightKeys}") + logDebug(s"leftKeys:$leftKeys | rightKeys:$rightKeys") Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right)) } else { None diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 8a80724c08c7c..996f577485132 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -544,7 +544,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql("SELECT * FROM upperCaseData EXCEPT SELECT * FROM upperCaseData"), Nil) } - test("INTERSECT") { + test("INTERSECT") { checkAnswer( sql("SELECT * FROM lowerCaseData INTERSECT SELECT * FROM lowerCaseData"), (1, "a") :: @@ -942,4 +942,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { checkAnswer(sql("SELECT key FROM testData WHERE value not like '100%' order by key"), (1 to 99).map(i => Seq(i))) } + + test("SPARK-4322 Grouping field with struct field as sub expression") { + jsonRDD(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)).registerTempTable("data") + checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), 1) + dropTempTable("data") + + jsonRDD(sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") + checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), 2) + dropTempTable("data") + } } From dd20a797f17e0901ecc4bafe296cf2f7e568d1cb Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 14 Nov 2014 18:42:42 +0800 Subject: [PATCH 2/3] Should only trim aliases around `GetField`s --- .../spark/sql/catalyst/analysis/Analyzer.scala | 11 +++++++---- .../spark/sql/catalyst/planning/patterns.scala | 16 +++++++++++----- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index dc54e732066fd..d3b4cf8e34242 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -70,8 +70,6 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool EliminateAnalysisOperators) ) - private def trimAliases(e: Expression) = e.transform { case Alias(c, _) => c } - /** * Makes sure all attributes and logical plans have been resolved. */ @@ -98,7 +96,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool object TrimGroupingAliases extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case Aggregate(groups, aggs, child) => - Aggregate(groups.map(trimAliases), aggs, child) + Aggregate(groups.map(_.transform { case Alias(c, _) => c }), aggs, child) } } @@ -118,7 +116,12 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool } aggregateExprs.find { e => - !isValidAggregateExpression(trimAliases(e)) + !isValidAggregateExpression(e.transform { + // Should trim aliases around `GetField`s. These aliases are introduced while + // resolving struct field accesses, because `GetField` is not a `NamedExpression`. + // (Should we just turn `GetField` into a `NamedExpression`?) + case Alias(g: GetField, _) => g + }) }.foreach { e => throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index be38b5c7f3a91..7eb7f29626c35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -146,17 +146,23 @@ object PartialAggregation { case other => (other, Alias(other, "PartialGroup")()) }.toMap - def trimAliases(e: Expression) = e.transform { case Alias(c, _) => c } + def trimGetFieldAliases(e: Expression) = e.transform { case Alias(g: GetField, _) => g } // Replace aggregations with a new expression that computes the result from the already // computed partial evaluations and grouping values. val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) => partialEvaluations(new TreeNodeRef(e)).finalEvaluation - case e: Expression if namedGroupingExpressions.contains(e) => - namedGroupingExpressions(e).toAttribute - case e: Expression if namedGroupingExpressions.contains(trimAliases(e)) => - namedGroupingExpressions(trimAliases(e)).toAttribute + + case e: Expression => + // Should trim aliases around `GetField`s. These aliases are introduced while + // resolving struct field accesses, because `GetField` is not a `NamedExpression`. + // (Should we just turn `GetField` into a `NamedExpression`?) + namedGroupingExpressions + .get(e) + .orElse(namedGroupingExpressions.get(trimGetFieldAliases(e))) + .map(_.toAttribute) + .getOrElse(e) }).asInstanceOf[Seq[NamedExpression]] val partialComputation = From 23a46ea1f9e623bc635849fa3ef469d64f5473e1 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 14 Nov 2014 19:10:14 +0800 Subject: [PATCH 3/3] Code simplification --- .../org/apache/spark/sql/catalyst/planning/patterns.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 7eb7f29626c35..310d127506d68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -146,8 +146,6 @@ object PartialAggregation { case other => (other, Alias(other, "PartialGroup")()) }.toMap - def trimGetFieldAliases(e: Expression) = e.transform { case Alias(g: GetField, _) => g } - // Replace aggregations with a new expression that computes the result from the already // computed partial evaluations and grouping values. val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { @@ -159,8 +157,7 @@ object PartialAggregation { // resolving struct field accesses, because `GetField` is not a `NamedExpression`. // (Should we just turn `GetField` into a `NamedExpression`?) namedGroupingExpressions - .get(e) - .orElse(namedGroupingExpressions.get(trimGetFieldAliases(e))) + .get(e.transform { case Alias(g: GetField, _) => g }) .map(_.toAttribute) .getOrElse(e) }).asInstanceOf[Seq[NamedExpression]]