From 5456a9566db038ec38ebb4f620e102b18365c342 Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Thu, 13 Oct 2016 22:01:26 +0800 Subject: [PATCH 1/9] remove bitmasks for grouping sets and use actual sets instead. --- .../sql/catalyst/analysis/Analyzer.scala | 60 ++++++++++++----- .../sql/catalyst/parser/AstBuilder.scala | 20 ++---- .../plans/logical/basicLogicalOperators.scala | 67 ++++++++++--------- .../sql/catalyst/parser/PlanParserSuite.scala | 3 +- 4 files changed, 85 insertions(+), 65 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8dbec408002f1..5bf2a297ed806 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -217,10 +217,16 @@ class Analyzer( * Group Count: N + 1 (N is the number of group expressions) * * We need to get all of its subsets for the rule described above, the subset is - * represented as the bit masks. + * represented as sequence of expressions. */ - def bitmasks(r: Rollup): Seq[Int] = { - Seq.tabulate(r.groupByExprs.length + 1)(idx => (1 << idx) - 1) + def selectGroupExprsRollup(exprs: Seq[Expression]): Seq[Seq[Expression]] = { + if (exprs.length == 0) { + Seq(Seq.empty[Expression]) + } else { + selectGroupExprsRollup(exprs.drop(1)).map { expandExprs => + exprs.take(1) ++ expandExprs + } ++ Seq(Seq.empty[Expression]) + } } /* @@ -230,10 +236,17 @@ class Analyzer( * Group Count: 2 ^ N (N is the number of group expressions) * * We need to get all of its subsets for a given GROUPBY expression, the subsets are - * represented as the bit masks. + * represented as sequence of expressions. */ - def bitmasks(c: Cube): Seq[Int] = { - Seq.tabulate(1 << c.groupByExprs.length)(i => i) + def selectGroupExprsCube(exprs: Seq[Expression]): Seq[Seq[Expression]] = { + if (exprs.length == 0) { + Seq(Seq.empty[Expression]) + } else { + val expandExprsList = selectGroupExprsCube(exprs.drop(1)) + expandExprsList.map { expandExprs => + exprs.take(1) ++ expandExprs + } ++ expandExprsList + } } private def hasGroupingAttribute(expr: Expression): Boolean = { @@ -282,9 +295,11 @@ class Analyzer( s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead") case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) => - GroupingSets(bitmasks(c), groupByExprs, child, aggregateExpressions) + GroupingSets( + selectGroupExprsCube(c.groupByExprs), groupByExprs, child, aggregateExpressions) case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) => - GroupingSets(bitmasks(r), groupByExprs, child, aggregateExpressions) + GroupingSets( + selectGroupExprsRollup(r.groupByExprs), groupByExprs, child, aggregateExpressions) // Ensure all the expressions have been resolved. case x: GroupingSets if x.expressions.forall(_.resolved) => @@ -299,18 +314,29 @@ class Analyzer( case other => Alias(other, other.toString)() } - // The rightmost bit in the bitmasks corresponds to the last expression in groupByAliases - // with 0 indicating this expression is in the grouping set. The following line of code - // calculates the bitmask representing the expressions that absent in at least one grouping - // set (indicated by 1). - val nullBitmask = x.bitmasks.reduce(_ | _) - - val attrLength = groupByAliases.length + // Change the nullability of group by aliases if necessary. 
For example, if we have
+        // GROUPING SETS ((a,b), a), we do not need to change the nullability of a, but we
+        // should change the nullability of b to be TRUE.
         val expandedAttributes = groupByAliases.zipWithIndex.map { case (a, idx) =>
-          a.toAttribute.withNullability(((nullBitmask >> (attrLength - idx - 1)) & 1) == 1)
+          if (x.selectedGroupByExprs.exists(!_.contains(a.child))) {
+            a.toAttribute.withNullability(true)
+          } else {
+            a.toAttribute
+          }
+        }
+
+        val groupingSetsAttributes = x.selectedGroupByExprs.map { groupingSetExprs =>
+          groupingSetExprs.map { expr =>
+            val alias = groupByAliases.find(_.child.semanticEquals(expr)).getOrElse(
+              failAnalysis(s"$expr doesn't show up in the GROUP BY list"))
+            // Map alias to expanded attribute.
+            expandedAttributes.find(_.semanticEquals(alias.toAttribute)).getOrElse(
+              alias.toAttribute)
+          }
         }
 
-        val expand = Expand(x.bitmasks, groupByAliases, expandedAttributes, gid, x.child)
+        val expand = Expand(
+          groupingSetsAttributes, groupByAliases, expandedAttributes, gid, x.child)
         val groupingAttrs = expand.output.drop(x.child.output.length)
 
         val aggregations: Seq[NamedExpression] = x.aggregations.map { case expr =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 4b151c81d8f8b..c7f7af4e52004 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -497,23 +497,13 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
 
     if (GROUPING != null) {
       // GROUP BY .... GROUPING SETS (...)
-      val expressionMap = groupByExpressions.zipWithIndex.toMap
-      val numExpressions = expressionMap.size
-      val mask = (1 << numExpressions) - 1
-      val masks = ctx.groupingSet.asScala.map {
-        _.expression.asScala.foldLeft(mask) {
-          case (bitmap, eCtx) =>
-            // Find the index of the expression.
-            val e = typedVisit[Expression](eCtx)
-            val index = expressionMap.find(_._1.semanticEquals(e)).map(_._2).getOrElse(
-              throw new ParseException(
-                s"$e doesn't show up in the GROUP BY list", ctx))
-            // 0 means that the column at the given index is a grouping column, 1 means it is not,
-            // so we unset the bit in bitmap.
-            bitmap & ~(1 << (numExpressions - 1 - index))
+      val selectedGroupByExprs = ctx.groupingSet.asScala.map {
+        _.expression.asScala.foldLeft(Seq.empty[Expression]) {
+          case (exprs, eCtx) =>
+            exprs :+ typedVisit[Expression](eCtx)
         }
       }
-      GroupingSets(masks, groupByExpressions, query, selectExpressions)
+      GroupingSets(selectedGroupByExprs, groupByExpressions, query, selectExpressions)
     } else {
       // GROUP BY .... (WITH CUBE | WITH ROLLUP)?
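      // For example, GROUP BY a, b WITH CUBE simply wraps the group-by list in
      // Cube(Seq(a, b)); ResolveGroupingAnalytics later expands it into all
      // 2^2 = 4 grouping sets: (a, b), (a), (b) and ().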
val mappedGroupByExpressions = if (CUBE != null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 65ceab2ce27b1..6a468866aa137 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical -import scala.collection.mutable.ArrayBuffer - +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.catalog.CatalogTypes @@ -523,59 +522,64 @@ case class Window( object Expand { /** - * Extract attribute set according to the grouping id. + * Build bit mask from attributes of selected grouping set. * - * @param bitmask bitmask to represent the selected of the attribute sequence - * @param attrs the attributes in sequence - * @return the attributes of non selected specified via bitmask (with the bit set to 1) + * @param groupingSetAttrs The attributes of selected grouping set + * @param attrMap Mapping group by attributes to its index in attributes sequence + * @return The bitmask which represents the selected attributes out of group by attributes. + * A bit in the bitmask is corresponding to an attribute in group by attributes sequence, + * the selected attribute has corresponding bit set to 0 and otherwise set to 1. */ - private def buildNonSelectAttrSet( - bitmask: Int, - attrs: Seq[Attribute]): AttributeSet = { - val nonSelect = new ArrayBuffer[Attribute]() - - var bit = attrs.length - 1 - while (bit >= 0) { - if (((bitmask >> bit) & 1) == 1) nonSelect += attrs(attrs.length - bit - 1) - bit -= 1 + private def buildBitmask( + groupingSetAttrs: Seq[Attribute], + attrMap: Map[Attribute, Int]): Int = { + val numAttributes = attrMap.size + val mask = (1 << numAttributes) - 1 + + groupingSetAttrs.foldLeft(mask) { + case (bitmap, attr) => + // Find the index of the attribute in sequence. + val index = attrMap.get(attr).getOrElse( + throw new AnalysisException(s"$attr doesn't show up in the GROUP BY list") + ) + // 0 means that the column at the given index is a grouping column, 1 means it is not, + // so we unset the bit in bitmap. + bitmap & ~(1 << (numAttributes - 1 - index)) } - - AttributeSet(nonSelect) } /** * Apply the all of the GroupExpressions to every input row, hence we will get * multiple output rows for an input row. * - * @param bitmasks The bitmask set represents the grouping sets + * @param groupingSetsAttrs The attributes of grouping sets * @param groupByAliases The aliased original group by expressions * @param groupByAttrs The attributes of aliased group by expressions * @param gid Attribute of the grouping id * @param child Child operator */ def apply( - bitmasks: Seq[Int], + groupingSetsAttrs: Seq[Seq[Attribute]], groupByAliases: Seq[Alias], groupByAttrs: Seq[Attribute], gid: Attribute, child: LogicalPlan): Expand = { + val attrMap = groupByAttrs.zipWithIndex.toMap + // Create an array of Projections for the child projection, and replace the projections' // expressions which equal GroupBy expressions with Literal(null), if those expressions - // are not set for this grouping set (according to the bit mask). 
-    val projections = bitmasks.map { bitmask =>
-      // get the non selected grouping attributes according to the bit mask
-      val nonSelectedGroupAttrSet = buildNonSelectAttrSet(bitmask, groupByAttrs)
-
+    // are not set for this grouping set.
+    val projections = groupingSetsAttrs.map { groupingSetAttrs =>
       child.output ++ groupByAttrs.map { attr =>
-        if (nonSelectedGroupAttrSet.contains(attr)) {
+        if (!groupingSetAttrs.contains(attr)) {
           // if the input attribute in the Invalid Grouping Expression set of for this group
           // replace it with constant null
           Literal.create(null, attr.dataType)
         } else {
           attr
         }
-      // groupingId is the last output, here we use the bit mask as the concrete value for it.
-      } :+ Literal.create(bitmask, IntegerType)
+        // groupingId is the last output, here we use the bit mask as the concrete value for it.
+      } :+ Literal.create(buildBitmask(groupingSetAttrs, attrMap), IntegerType)
     }
 
     // the `groupByAttrs` has different meaning in `Expand.output`, it could be the original
@@ -616,16 +620,15 @@ case class Expand(
  *
  * We will transform GROUPING SETS into logical plan Aggregate(.., Expand) in Analyzer
  *
- * @param bitmasks A list of bitmasks, each of the bitmask indicates the selected
- *                 GroupBy expressions
- * @param groupByExprs The Group By expressions candidates, take effective only if the
- *                     associated bit in the bitmask set to 1.
+ * @param selectedGroupByExprs A sequence of selected GroupBy expressions, all exprs should
+ *                             exist in groupByExprs.
+ * @param groupByExprs The Group By expressions candidates.
  * @param child Child operator
  * @param aggregations The Aggregation expressions, those non selected group by expressions
  *                     will be considered as constant null if it appears in the expressions
  */
 case class GroupingSets(
-    bitmasks: Seq[Int],
+    selectedGroupByExprs: Seq[Seq[Expression]],
     groupByExprs: Seq[Expression],
     child: LogicalPlan,
     aggregations: Seq[NamedExpression]) extends UnaryNode {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 7400f3430e99c..64952e2ff8311 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -233,7 +233,8 @@ class PlanParserSuite extends PlanTest {
 
     // Grouping Sets
     assertEqual(s"$sql grouping sets((a, b), (a), ())",
-      GroupingSets(Seq(0, 1, 3), Seq('a, 'b), table("d"), Seq('a, 'b, 'sum.function('c).as("c"))))
+      GroupingSets(Seq(Seq('a, 'b), Seq('a), Seq()), Seq('a, 'b), table("d"),
+        Seq('a, 'b, 'sum.function('c).as("c"))))
     intercept(s"$sql grouping sets((a, b), (c), ())",
       "c doesn't show up in the GROUP BY list")
   }

From 66efa0061ea32d8774f705b8b0821bd6463ea306 Mon Sep 17 00:00:00 2001
From: jiangxingbo
Date: Fri, 14 Oct 2016 21:21:01 +0800
Subject: [PATCH 2/9] update testcases.
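
The new suite pins down the exact shape of the resolved plans: Expand gets
one projection per selected grouping set, with the non-selected group-by
columns replaced by nulls and the grouping id appended as a bitmask (e.g.
3, 1 and 0 for the sets (), (a) and (a, b) under GROUP BY a, b).

As a rough, standalone illustration of that contract (plain Scala, no
Catalyst types; `Row` and `expand` are names invented for this note, not
part of the patch):

    object GroupingSetsSketch {
      type Row = Map[String, Any]

      // One output row per grouping set: null out the group-by columns the
      // set does not contain, and append the grouping id, a bitmask whose
      // bit (n - 1 - i) is 1 iff group-by column i is absent from the set.
      def expand(row: Row, groupBy: Seq[String], sets: Seq[Seq[String]]): Seq[(Row, Int)] = {
        val n = groupBy.length
        sets.map { set =>
          val projected = row ++ groupBy.filterNot(set.contains).map(_ -> null)
          val gid = groupBy.zipWithIndex.collect {
            case (col, i) if !set.contains(col) => 1 << (n - 1 - i)
          }.sum
          (projected, gid)
        }
      }

      def main(args: Array[String]): Unit = {
        val row = Map[String, Any]("a" -> 1, "b" -> "x", "c" -> "y")
        // GROUPING SETS ((a, b), (a), ()) over GROUP BY a, b:
        expand(row, Seq("a", "b"), Seq(Seq("a", "b"), Seq("a"), Seq())).foreach(println)
        // => gid 0 for (a, b); gid 1 for (a) with b nulled; gid 3 for () with a and b nulled
      }
    }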
--- .../sql/catalyst/analysis/Analyzer.scala | 5 +- .../ResolveGroupingAnalyticsSuite.scala | 244 ++++++++++++++++++ .../sql/catalyst/parser/PlanParserSuite.scala | 2 - 3 files changed, 247 insertions(+), 4 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5bf2a297ed806..705ffcea73234 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -367,13 +367,14 @@ class Analyzer( Aggregate(groupingAttrs, aggregations, expand) - case f @ Filter(cond, child) if hasGroupingFunction(cond) => + case f @ Filter(cond, child) if hasGroupingFunction(cond) && cond.resolved => val groupingExprs = findGroupingExprs(child) // The unresolved grouping id will be resolved by ResolveMissingReferences val newCond = replaceGroupingFunc(cond, groupingExprs, VirtualColumn.groupingIdAttribute) f.copy(condition = newCond) - case s @ Sort(order, _, child) if order.exists(hasGroupingFunction) => + case s @ Sort(order, _, child) + if order.exists(hasGroupingFunction) && order.forall(_.resolved) => val groupingExprs = findGroupingExprs(child) val gid = VirtualColumn.groupingIdAttribute // The unresolved grouping id will be resolved by ResolveMissingReferences diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala new file mode 100644 index 0000000000000..dcf19143f80c8 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types._ + +class ResolveGroupingAnalyticsSuite extends AnalysisTest { + + lazy val a = 'a.int + lazy val b = 'b.string + lazy val c = 'c.string + lazy val unresolved_a = UnresolvedAttribute("a") + lazy val unresolved_b = UnresolvedAttribute("b") + lazy val unresolved_c = UnresolvedAttribute("c") + lazy val gid = 'spark_grouping_id.int.withNullability(false) + lazy val hive_gid = 'grouping__id.int.withNullability(false) + lazy val grouping_a = Cast(ShiftRight(gid, 1) & 1, ByteType) + lazy val nulInt = Literal(null, IntegerType) + lazy val nulStr = Literal(null, StringType) + lazy val r1 = LocalRelation(a, b, c) + + test("grouping sets") { + val originalPlan = GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), + Seq(unresolved_a, unresolved_b), r1, + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)))) + val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")), + Expand( + Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)), + Seq(a, b, c, a, b, gid), + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + checkAnalysis(originalPlan, expected) + + val originalPlan2 = GroupingSets(Seq(), Seq(unresolved_a, unresolved_b), r1, + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)))) + val expected2 = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")), + Expand( + Seq(), + Seq(a, b, c, a, b, gid), + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + checkAnalysis(originalPlan2, expected2) + + val originalPlan3 = GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b), + Seq(unresolved_c)), Seq(unresolved_a, unresolved_b), r1, + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)))) + assertAnalysisError(originalPlan3, Seq("doesn't show up in the GROUP BY list")) + } + + test("cube") { + val originalPlan = Aggregate(Seq(Cube(Seq(unresolved_a, unresolved_b))), + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1) + val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")), + Expand( + Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1), + Seq(a, b, c, nulInt, b, 2), Seq(a, b, c, nulInt, nulStr, 3)), + Seq(a, b, c, a, b, gid), + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + checkAnalysis(originalPlan, expected) + + val originalPlan2 = Aggregate(Seq(Cube(Seq())), Seq(UnresolvedAlias(count(unresolved_c))), r1) + val expected2 = Aggregate(Seq(gid), Seq(count(c).as("count(c)")), + Expand( + Seq(Seq(a, b, c, 0)), + Seq(a, b, c, gid), + Project(Seq(a, b, c), r1))) + checkAnalysis(originalPlan2, expected2) + } + + test("rollup") { + val originalPlan = Aggregate(Seq(Rollup(Seq(unresolved_a, unresolved_b))), + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1) + val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")), + Expand( + Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, nulInt, nulStr, 3)), + Seq(a, b, c, a, b, gid), + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + checkAnalysis(originalPlan, expected) + + val originalPlan2 = Aggregate(Seq(Rollup(Seq())), Seq(UnresolvedAlias(count(unresolved_c))), r1) + val expected2 = Aggregate(Seq(gid), Seq(count(c).as("count(c)")), + Expand( + 
Seq(Seq(a, b, c, 0)),
+        Seq(a, b, c, gid),
+        Project(Seq(a, b, c), r1)))
+    checkAnalysis(originalPlan2, expected2)
+  }
+
+  test("grouping function") {
+    // GroupingSets
+    val originalPlan = GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)),
+      Seq(unresolved_a, unresolved_b), r1,
+      Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)),
+        UnresolvedAlias(Grouping(unresolved_a))))
+    val expected = Aggregate(Seq(a, b, gid),
+      Seq(a, b, count(c).as("count(c)"), grouping_a.as("grouping(a)")),
+      Expand(
+        Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)),
+        Seq(a, b, c, a, b, gid),
+        Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
+    checkAnalysis(originalPlan, expected)
+
+    // Cube
+    val originalPlan2 = Aggregate(Seq(Cube(Seq(unresolved_a, unresolved_b))),
+      Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)),
+        UnresolvedAlias(Grouping(unresolved_a))), r1)
+    val expected2 = Aggregate(Seq(a, b, gid),
+      Seq(a, b, count(c).as("count(c)"), grouping_a.as("grouping(a)")),
+      Expand(
+        Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1),
+          Seq(a, b, c, nulInt, b, 2), Seq(a, b, c, nulInt, nulStr, 3)),
+        Seq(a, b, c, a, b, gid),
+        Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
+    checkAnalysis(originalPlan2, expected2)
+
+    // Rollup
+    val originalPlan3 = Aggregate(Seq(Rollup(Seq(unresolved_a, unresolved_b))),
+      Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)),
+        UnresolvedAlias(Grouping(unresolved_a))), r1)
+    val expected3 = Aggregate(Seq(a, b, gid),
+      Seq(a, b, count(c).as("count(c)"), grouping_a.as("grouping(a)")),
+      Expand(
+        Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, nulInt, nulStr, 3)),
+        Seq(a, b, c, a, b, gid),
+        Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
+    checkAnalysis(originalPlan3, expected3)
+  }
+
+  test("grouping_id") {
+    // GroupingSets
+    val originalPlan = GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)),
+      Seq(unresolved_a, unresolved_b), r1,
+      Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)),
+        UnresolvedAlias(GroupingID(Seq(unresolved_a, unresolved_b)))))
+    val expected = Aggregate(Seq(a, b, gid),
+      Seq(a, b, count(c).as("count(c)"), gid.as("grouping_id(a, b)")),
+      Expand(
+        Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)),
+        Seq(a, b, c, a, b, gid),
+        Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
+    checkAnalysis(originalPlan, expected)
+
+    // Cube
+    val originalPlan2 = Aggregate(Seq(Cube(Seq(unresolved_a, unresolved_b))),
+      Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)),
+        UnresolvedAlias(GroupingID(Seq(unresolved_a, unresolved_b)))), r1)
+    val expected2 = Aggregate(Seq(a, b, gid),
+      Seq(a, b, count(c).as("count(c)"), gid.as("grouping_id(a, b)")),
+      Expand(
+        Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1),
+          Seq(a, b, c, nulInt, b, 2), Seq(a, b, c, nulInt, nulStr, 3)),
+        Seq(a, b, c, a, b, gid),
+        Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
+    checkAnalysis(originalPlan2, expected2)
+
+    // Rollup
+    val originalPlan3 = Aggregate(Seq(Rollup(Seq(unresolved_a, unresolved_b))),
+      Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)),
+        UnresolvedAlias(GroupingID(Seq(unresolved_a, unresolved_b)))), r1)
+    val expected3 = Aggregate(Seq(a, b, gid),
+      Seq(a, b, count(c).as("count(c)"), gid.as("grouping_id(a, b)")),
+      Expand(
+        Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, nulInt, nulStr,
3)), + Seq(a, b, c, a, b, gid), + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + checkAnalysis(originalPlan3, expected3) + } + + test("filter with grouping function") { + // Filter with Grouping function + val originalPlan = Filter(Grouping(unresolved_a) === 0, + GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), + Seq(unresolved_a, unresolved_b), r1, Seq(unresolved_a, unresolved_b))) + val expected = Project(Seq(a, b), Filter(Cast(grouping_a, IntegerType) === 0, + Aggregate(Seq(a, b, gid), + Seq(a, b, gid), + Expand( + Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)), + Seq(a, b, c, a, b, gid), + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))))) + checkAnalysis(originalPlan, expected) + + // Filter with GroupingID + val originalPlan2 = Filter(GroupingID(Seq(unresolved_a, unresolved_b)) === 1, + GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), + Seq(unresolved_a, unresolved_b), r1, Seq(unresolved_a, unresolved_b))) + val expected2 = Project(Seq(a, b), Filter(gid === 1, + Aggregate(Seq(a, b, gid), + Seq(a, b, gid), + Expand( + Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)), + Seq(a, b, c, a, b, gid), + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))))) + checkAnalysis(originalPlan2, expected2) + } + + test("sort with grouping function") { + // Sort with Grouping function + val originalPlan = Sort( + Seq(SortOrder(Grouping(unresolved_a), Ascending)), true, + GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), + Seq(unresolved_a, unresolved_b), r1, Seq(unresolved_a, unresolved_b))) + val expected = Project(Seq(a, b), Sort( + Seq(SortOrder('aggOrder.byte.withNullability(false), Ascending)), true, + Aggregate(Seq(a, b, gid), + Seq(a, b, grouping_a.as("aggOrder")), + Expand( + Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)), + Seq(a, b, c, a, b, gid), + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))))) + checkAnalysis(originalPlan, expected) + + // Sort with GroupingID + val originalPlan2 = Sort( + Seq(SortOrder(GroupingID(Seq(unresolved_a, unresolved_b)), Ascending)), true, + GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), + Seq(unresolved_a, unresolved_b), r1, Seq(unresolved_a, unresolved_b))) + val expected2 = Project(Seq(a, b), Sort( + Seq(SortOrder('aggOrder.int.withNullability(false), Ascending)), true, + Aggregate(Seq(a, b, gid), + Seq(a, b, gid.as("aggOrder")), + Expand( + Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)), + Seq(a, b, c, a, b, gid), + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))))) + checkAnalysis(originalPlan2, expected2) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 64952e2ff8311..5f0f6ee479c69 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -235,8 +235,6 @@ class PlanParserSuite extends PlanTest { assertEqual(s"$sql grouping sets((a, b), (a), ())", GroupingSets(Seq(Seq('a, 'b), Seq('a), Seq()), Seq('a, 'b), table("d"), Seq('a, 'b, 'sum.function('c).as("c")))) - intercept(s"$sql grouping sets((a, b), (c), ())", - "c doesn't show up in the GROUP BY list") } test("limit") { From 
bf5d419e76486a93a45c88f3f16e2c87d623b2de Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Fri, 14 Oct 2016 21:51:41 +0800 Subject: [PATCH 3/9] update comments. --- .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 ++ .../sql/catalyst/plans/logical/basicLogicalOperators.scala | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 705ffcea73234..efcddec8af14d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -367,12 +367,14 @@ class Analyzer( Aggregate(groupingAttrs, aggregations, expand) + // We should make sure all expressions in condition have been resolved. case f @ Filter(cond, child) if hasGroupingFunction(cond) && cond.resolved => val groupingExprs = findGroupingExprs(child) // The unresolved grouping id will be resolved by ResolveMissingReferences val newCond = replaceGroupingFunc(cond, groupingExprs, VirtualColumn.groupingIdAttribute) f.copy(condition = newCond) + // We should make sure all [[SortOrder]]s have been resolved. case s @ Sort(order, _, child) if order.exists(hasGroupingFunction) && order.forall(_.resolved) => val groupingExprs = findGroupingExprs(child) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 6a468866aa137..075dd857f951f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -578,7 +578,7 @@ object Expand { } else { attr } - // groupingId is the last output, here we use the bit mask as the concrete value for it. + // groupingId is the last output, here we use the bit mask as the concrete value for it. } :+ Literal.create(buildBitmask(groupingSetAttrs, attrMap), IntegerType) } From 16bd22b5232cd7918e49f587763e970cad88108a Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Sat, 15 Oct 2016 00:17:04 +0800 Subject: [PATCH 4/9] add test cases. 
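
Besides the happy paths, these cases pin down the failure message for
grouping()/grouping_id() used without GroupingSets/Cube/Rollup, and the
bit arithmetic the analyzer substitutes for grouping(): a single bit of
the internal grouping id, with grouping_id() being the id itself. A
self-contained sketch of that mapping (illustrative names only, not
Catalyst code):

    object GroupingFuncSketch {
      // Mirrors Cast(ShiftRight(gid, n - 1 - idx) & 1, ByteType) from the
      // analyzer: 0 when `col` is grouped in this row, 1 when it was nulled.
      def grouping(gid: Int, groupBy: Seq[String], col: String): Byte = {
        val idx = groupBy.indexOf(col)
        require(idx >= 0, s"Column of grouping ($col) can't be found in grouping columns")
        ((gid >> (groupBy.length - 1 - idx)) & 1).toByte
      }

      def main(args: Array[String]): Unit = {
        // With GROUP BY a, b, gid = 1 encodes the grouping set (a):
        println(grouping(1, Seq("a", "b"), "a")) // 0
        println(grouping(1, Seq("a", "b"), "b")) // 1
      }
    }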
--- .../ResolveGroupingAnalyticsSuite.scala | 33 +++++++++++++++---- 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala index dcf19143f80c8..0b36bbdc83ab8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala @@ -196,18 +196,28 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))))) checkAnalysis(originalPlan, expected) + val originalPlan2 = Filter(Grouping(unresolved_a) === 0, + Aggregate(Seq(unresolved_a), Seq(UnresolvedAlias(count(unresolved_b))), r1)) + assertAnalysisError(originalPlan2, + Seq("grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")) + // Filter with GroupingID - val originalPlan2 = Filter(GroupingID(Seq(unresolved_a, unresolved_b)) === 1, + val originalPlan3 = Filter(GroupingID(Seq(unresolved_a, unresolved_b)) === 1, GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), Seq(unresolved_a, unresolved_b), r1, Seq(unresolved_a, unresolved_b))) - val expected2 = Project(Seq(a, b), Filter(gid === 1, + val expected3 = Project(Seq(a, b), Filter(gid === 1, Aggregate(Seq(a, b, gid), Seq(a, b, gid), Expand( Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)), Seq(a, b, c, a, b, gid), Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))))) - checkAnalysis(originalPlan2, expected2) + checkAnalysis(originalPlan3, expected3) + + val originalPlan4 = Filter(GroupingID(Seq(unresolved_a)) === 1, + Aggregate(Seq(unresolved_a), Seq(UnresolvedAlias(count(unresolved_b))), r1)) + assertAnalysisError(originalPlan4, + Seq("grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")) } test("sort with grouping function") { @@ -226,12 +236,17 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))))) checkAnalysis(originalPlan, expected) + val originalPlan2 = Sort(Seq(SortOrder(Grouping(unresolved_a), Ascending)), true, + Aggregate(Seq(unresolved_a), Seq(unresolved_a, UnresolvedAlias(count(unresolved_b))), r1)) + assertAnalysisError(originalPlan2, + Seq("grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")) + // Sort with GroupingID - val originalPlan2 = Sort( + val originalPlan3 = Sort( Seq(SortOrder(GroupingID(Seq(unresolved_a, unresolved_b)), Ascending)), true, GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), Seq(unresolved_a, unresolved_b), r1, Seq(unresolved_a, unresolved_b))) - val expected2 = Project(Seq(a, b), Sort( + val expected3 = Project(Seq(a, b), Sort( Seq(SortOrder('aggOrder.int.withNullability(false), Ascending)), true, Aggregate(Seq(a, b, gid), Seq(a, b, gid.as("aggOrder")), @@ -239,6 +254,12 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)), Seq(a, b, c, a, b, gid), Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))))) - checkAnalysis(originalPlan2, expected2) + checkAnalysis(originalPlan3, expected3) + + val originalPlan4 = Sort( + Seq(SortOrder(GroupingID(Seq(unresolved_a)), Ascending)), true, + Aggregate(Seq(unresolved_a), Seq(unresolved_a, 
UnresolvedAlias(count(unresolved_b))), r1)) + assertAnalysisError(originalPlan4, + Seq("grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")) } } From 406060a221f6b20e06ff0d48dfd9d0fd7961655a Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Thu, 20 Oct 2016 00:41:44 +0800 Subject: [PATCH 5/9] improve code quality. --- .../sql/catalyst/analysis/Analyzer.scala | 33 +++++++++---------- .../sql/catalyst/parser/AstBuilder.scala | 11 ++----- .../plans/logical/basicLogicalOperators.scala | 16 +++------ 3 files changed, 23 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index efcddec8af14d..b4bf3e5c48a13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -219,14 +219,14 @@ class Analyzer( * We need to get all of its subsets for the rule described above, the subset is * represented as sequence of expressions. */ - def selectGroupExprsRollup(exprs: Seq[Expression]): Seq[Seq[Expression]] = { - if (exprs.length == 0) { - Seq(Seq.empty[Expression]) - } else { - selectGroupExprsRollup(exprs.drop(1)).map { expandExprs => - exprs.take(1) ++ expandExprs - } ++ Seq(Seq.empty[Expression]) + def rollupExprs(exprs: Seq[Expression]): Seq[Seq[Expression]] = { + val buffer = ArrayBuffer.empty[Seq[Expression]] + var current = exprs + while (current.nonEmpty) { + buffer += current + current = current.init } + buffer } /* @@ -238,15 +238,12 @@ class Analyzer( * We need to get all of its subsets for a given GROUPBY expression, the subsets are * represented as sequence of expressions. */ - def selectGroupExprsCube(exprs: Seq[Expression]): Seq[Seq[Expression]] = { - if (exprs.length == 0) { - Seq(Seq.empty[Expression]) - } else { - val expandExprsList = selectGroupExprsCube(exprs.drop(1)) - expandExprsList.map { expandExprs => - exprs.take(1) ++ expandExprs - } ++ expandExprsList - } + def cubeExprs(exprs: Seq[Expression]): Seq[Seq[Expression]] = exprs match { + case x :: xs => + val initial = cubeExprs(xs) + initial.map(x +: _) ++ initial + case Nil => + Seq(Seq.empty) } private def hasGroupingAttribute(expr: Expression): Boolean = { @@ -296,10 +293,10 @@ class Analyzer( case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) => GroupingSets( - selectGroupExprsCube(c.groupByExprs), groupByExprs, child, aggregateExpressions) + cubeExprs(c.groupByExprs), groupByExprs, child, aggregateExpressions) case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) => GroupingSets( - selectGroupExprsRollup(r.groupByExprs), groupByExprs, child, aggregateExpressions) + rollupExprs(r.groupByExprs), groupByExprs, child, aggregateExpressions) // Ensure all the expressions have been resolved. 
case x: GroupingSets if x.expressions.forall(_.resolved) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index c7f7af4e52004..16b4c23b4ecd5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -492,17 +492,12 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { ctx: AggregationContext, selectExpressions: Seq[NamedExpression], query: LogicalPlan): LogicalPlan = withOrigin(ctx) { - import ctx._ - val groupByExpressions = expressionList(groupingExpressions) + val groupByExpressions = expressionList(ctx.groupingExpressions) if (GROUPING != null) { // GROUP BY .... GROUPING SETS (...) - val selectedGroupByExprs = ctx.groupingSet.asScala.map { - _.expression.asScala.foldLeft(Seq.empty[Expression]) { - case (exprs, eCtx) => - exprs :+ typedVisit[Expression](eCtx) - } - } + val selectedGroupByExprs = + ctx.groupingSet.asScala.map(_.expression.asScala.map(e => expression(e))) GroupingSets(selectedGroupByExprs, groupByExpressions, query, selectExpressions) } else { // GROUP BY .... (WITH CUBE | WITH ROLLUP)? diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 075dd857f951f..1760bed66af04 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -535,17 +535,11 @@ object Expand { attrMap: Map[Attribute, Int]): Int = { val numAttributes = attrMap.size val mask = (1 << numAttributes) - 1 - - groupingSetAttrs.foldLeft(mask) { - case (bitmap, attr) => - // Find the index of the attribute in sequence. - val index = attrMap.get(attr).getOrElse( - throw new AnalysisException(s"$attr doesn't show up in the GROUP BY list") - ) - // 0 means that the column at the given index is a grouping column, 1 means it is not, - // so we unset the bit in bitmap. - bitmap & ~(1 << (numAttributes - 1 - index)) - } + groupingSetAttrs.map(attrMap).map(index => + // 0 means that the column at the given index is a grouping column, 1 means it is not, + // so we unset the bit in bitmap. 
+ ~(1 << (numAttributes - 1 - index)) + ).reduce(_ & _) & mask } /** From 481a6acd4a8742b4760cfa92270675abccff6294 Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Thu, 20 Oct 2016 17:08:38 +0800 Subject: [PATCH 6/9] refactor ResolveGroupingAnalytics --- .../sql/catalyst/analysis/Analyzer.scala | 214 +++++++++++------- .../sql/catalyst/parser/AstBuilder.scala | 6 +- .../plans/logical/basicLogicalOperators.scala | 4 +- 3 files changed, 131 insertions(+), 93 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b4bf3e5c48a13..ec57627cbbafd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -226,7 +226,7 @@ class Analyzer( buffer += current current = current.init } - buffer + buffer += Seq.empty } /* @@ -266,17 +266,17 @@ class Analyzer( expr transform { case e: GroupingID => if (e.groupByExprs.isEmpty || e.groupByExprs == groupByExprs) { - gid + Alias(gid, toPrettySQL(e))() } else { throw new AnalysisException( s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " + s"grouping columns (${groupByExprs.mkString(",")})") } - case Grouping(col: Expression) => + case e @ Grouping(col: Expression) => val idx = groupByExprs.indexOf(col) if (idx >= 0) { - Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)), - Literal(1)), ByteType) + Alias(Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)), + Literal(1)), ByteType), toPrettySQL(e))() } else { throw new AnalysisException(s"Column of grouping ($col) can't be found " + s"in grouping columns ${groupByExprs.mkString(",")}") @@ -284,6 +284,124 @@ class Analyzer( } } + /* + * Create new alias for all group by expressions for `Expand` operator. + */ + private def constructGroupByAlias(groupByExprs: Seq[Expression]): Seq[Alias] = { + groupByExprs.map { + case e: NamedExpression => Alias(e, e.name)() + case other => Alias(other, other.toString)() + } + } + + /* + * Construct [[Expand]] operator with grouping sets. + */ + private def constructExpand( + selectedGroupByExprs: Seq[Seq[Expression]], + child: LogicalPlan, + groupByAliases: Seq[Alias], + gid: Attribute): LogicalPlan = { + // Change the nullability of group by aliases if necessary. For example, if we have + // GROUPING SETS ((a,b), a), we do not need to change the nullability of a, but we + // should change the nullabilty of b to be TRUE. + // TODO: For Cube/Rollup just set nullability to be `true`. + val expandedAttributes = groupByAliases.zipWithIndex.map { case (a, idx) => + if (selectedGroupByExprs.exists(!_.contains(a.child))) { + a.toAttribute.withNullability(true) + } else { + a.toAttribute + } + } + + val groupingSetsAttributes = selectedGroupByExprs.map { groupingSetExprs => + groupingSetExprs.map { expr => + val alias = groupByAliases.find(_.child.semanticEquals(expr)).getOrElse( + failAnalysis(s"$expr doesn't show up in the GROUP BY list")) + // Map alias to expanded attribute. + expandedAttributes.find(_.semanticEquals(alias.toAttribute)).getOrElse( + alias.toAttribute) + } + } + + Expand(groupingSetsAttributes, groupByAliases, expandedAttributes, gid, child) + } + + /* + * Construct new aggregate expressions by replacing grouping functions. 
+ */ + private def constructAggregateExprs( + groupByExprs: Seq[Expression], + aggregations: Seq[NamedExpression], + groupByAliases: Seq[Alias], + groupingAttrs: Seq[Expression], + gid: Attribute): Seq[NamedExpression] = aggregations.map { case expr => + // collect all the found AggregateExpression, so we can check an expression is part of + // any AggregateExpression or not. + val aggsBuffer = ArrayBuffer[Expression]() + // Returns whether the expression belongs to any expressions in `aggsBuffer` or not. + def isPartOfAggregation(e: Expression): Boolean = { + aggsBuffer.exists(a => a.find(_ eq e).isDefined) + } + replaceGroupingFunc(expr, groupByExprs, gid).transformDown { + // AggregateExpression should be computed on the unmodified value of its argument + // expressions, so we should not replace any references to grouping expression + // inside it. + case e: AggregateExpression => + aggsBuffer += e + e + case e if isPartOfAggregation(e) => e + case e => + // Replace expression by expand output attribute. + val index = groupByAliases.indexWhere(_.child.semanticEquals(e)) + if (index == -1) { + e + } else { + groupingAttrs(index) + } + }.asInstanceOf[NamedExpression] + } + + /* + * Construct [[Aggregate]] operator from Cube/Rollup/GroupingSets. + */ + private def constructAggregate( + selectedGroupByExprs: Seq[Seq[Expression]], + groupByExprs: Seq[Expression], + aggregationExprs: Seq[NamedExpression], + child: LogicalPlan): LogicalPlan = { + val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)() + + // Expand works by setting grouping expressions to null as determined by the + // `selectedGroupByExprs`. To prevent these null values from being used in an aggregate + // instead of the original value we need to create new aliases for all group by expressions + // that will only be used for the intended purpose. + val groupByAliases = constructGroupByAlias(groupByExprs) + + val expand = constructExpand(selectedGroupByExprs, child, groupByAliases, gid) + val groupingAttrs = expand.output.drop(child.output.length) + + val aggregations = constructAggregateExprs( + groupByExprs, aggregationExprs, groupByAliases, groupingAttrs, gid) + + Aggregate(groupingAttrs, aggregations, expand) + } + + private def findGroupingExprs(plan: LogicalPlan): Seq[Expression] = { + plan.collectFirst { + case a: Aggregate => + // this Aggregate should have grouping id as the last grouping key. + val gid = a.groupingExpressions.last + if (!gid.isInstanceOf[AttributeReference] + || gid.asInstanceOf[AttributeReference].name != VirtualColumn.groupingIdName) { + failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") + } + a.groupingExpressions.take(a.groupingExpressions.length - 1) + }.getOrElse { + failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") + } + } + // This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case a if !a.childrenResolved => a // be sure all of the children are resolved. 
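      // For example, with GROUP BY a, b the rules below expand to the grouping sets
      //   rollupExprs(Seq(a, b)) => Seq(Seq(a, b), Seq(a), Seq())
      //   cubeExprs(Seq(a, b))   => Seq(Seq(a, b), Seq(a), Seq(b), Seq())
      // before constructAggregate builds one Expand projection per selected set.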
@@ -292,77 +410,12 @@ class Analyzer( s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead") case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) => - GroupingSets( - cubeExprs(c.groupByExprs), groupByExprs, child, aggregateExpressions) + constructAggregate(cubeExprs(groupByExprs), groupByExprs, aggregateExpressions, child) case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) => - GroupingSets( - rollupExprs(r.groupByExprs), groupByExprs, child, aggregateExpressions) - + constructAggregate(rollupExprs(groupByExprs), groupByExprs, aggregateExpressions, child) // Ensure all the expressions have been resolved. case x: GroupingSets if x.expressions.forall(_.resolved) => - val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)() - - // Expand works by setting grouping expressions to null as determined by the bitmasks. To - // prevent these null values from being used in an aggregate instead of the original value - // we need to create new aliases for all group by expressions that will only be used for - // the intended purpose. - val groupByAliases: Seq[Alias] = x.groupByExprs.map { - case e: NamedExpression => Alias(e, e.name)() - case other => Alias(other, other.toString)() - } - - // Change the nullability of group by aliases if necessary. For example, if we have - // GROUPING SETS ((a,b), a), we do not need to change the nullability of a, but we - // should change the nullabilty of b to be TRUE. - val expandedAttributes = groupByAliases.zipWithIndex.map { case (a, idx) => - if (x.selectedGroupByExprs.exists(!_.contains(a.child))) { - a.toAttribute.withNullability(true) - } else { - a.toAttribute - } - } - - val groupingSetsAttributes = x.selectedGroupByExprs.map { groupingSetExprs => - groupingSetExprs.map { expr => - val alias = groupByAliases.find(_.child.semanticEquals(expr)).getOrElse( - failAnalysis(s"$expr doesn't show up in the GROUP BY list")) - // Map alias to expanded attribute. - expandedAttributes.find(_.semanticEquals(alias.toAttribute)).getOrElse( - alias.toAttribute) - } - } - - val expand = Expand( - groupingSetsAttributes, groupByAliases, expandedAttributes, gid, x.child) - val groupingAttrs = expand.output.drop(x.child.output.length) - - val aggregations: Seq[NamedExpression] = x.aggregations.map { case expr => - // collect all the found AggregateExpression, so we can check an expression is part of - // any AggregateExpression or not. - val aggsBuffer = ArrayBuffer[Expression]() - // Returns whether the expression belongs to any expressions in `aggsBuffer` or not. - def isPartOfAggregation(e: Expression): Boolean = { - aggsBuffer.exists(a => a.find(_ eq e).isDefined) - } - replaceGroupingFunc(expr, x.groupByExprs, gid).transformDown { - // AggregateExpression should be computed on the unmodified value of its argument - // expressions, so we should not replace any references to grouping expression - // inside it. - case e: AggregateExpression => - aggsBuffer += e - e - case e if isPartOfAggregation(e) => e - case e => - val index = groupByAliases.indexWhere(_.child.semanticEquals(e)) - if (index == -1) { - e - } else { - groupingAttrs(index) - } - }.asInstanceOf[NamedExpression] - } - - Aggregate(groupingAttrs, aggregations, expand) + constructAggregate(x.selectedGroupByExprs, x.groupByExprs, x.aggregations, x.child) // We should make sure all expressions in condition have been resolved. 
case f @ Filter(cond, child) if hasGroupingFunction(cond) && cond.resolved => @@ -380,21 +433,6 @@ class Analyzer( val newOrder = order.map(replaceGroupingFunc(_, groupingExprs, gid).asInstanceOf[SortOrder]) s.copy(order = newOrder) } - - private def findGroupingExprs(plan: LogicalPlan): Seq[Expression] = { - plan.collectFirst { - case a: Aggregate => - // this Aggregate should have grouping id as the last grouping key. - val gid = a.groupingExpressions.last - if (!gid.isInstanceOf[AttributeReference] - || gid.asInstanceOf[AttributeReference].name != VirtualColumn.groupingIdName) { - failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") - } - a.groupingExpressions.take(a.groupingExpressions.length - 1) - }.getOrElse { - failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") - } - } } object ResolvePivot extends Rule[LogicalPlan] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 16b4c23b4ecd5..2c4db0d2c3425 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -494,16 +494,16 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { query: LogicalPlan): LogicalPlan = withOrigin(ctx) { val groupByExpressions = expressionList(ctx.groupingExpressions) - if (GROUPING != null) { + if (ctx.GROUPING != null) { // GROUP BY .... GROUPING SETS (...) val selectedGroupByExprs = ctx.groupingSet.asScala.map(_.expression.asScala.map(e => expression(e))) GroupingSets(selectedGroupByExprs, groupByExpressions, query, selectExpressions) } else { // GROUP BY .... (WITH CUBE | WITH ROLLUP)? - val mappedGroupByExpressions = if (CUBE != null) { + val mappedGroupByExpressions = if (ctx.CUBE != null) { Seq(Cube(groupByExpressions)) - } else if (ROLLUP != null) { + } else if (ctx.ROLLUP != null) { Seq(Rollup(groupByExpressions)) } else { groupByExpressions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 1760bed66af04..28dddd2062e08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -535,11 +535,11 @@ object Expand { attrMap: Map[Attribute, Int]): Int = { val numAttributes = attrMap.size val mask = (1 << numAttributes) - 1 - groupingSetAttrs.map(attrMap).map(index => + (mask +: groupingSetAttrs.map(attrMap).map(index => // 0 means that the column at the given index is a grouping column, 1 means it is not, // so we unset the bit in bitmap. 
~(1 << (numAttributes - 1 - index)) - ).reduce(_ & _) & mask + )).reduce(_ & _) } /** From 55873435a470b64dcb32b3a92552b4a84ebbdc27 Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Fri, 21 Oct 2016 14:09:50 +0800 Subject: [PATCH 7/9] bugfix --- .../apache/spark/sql/catalyst/analysis/Analyzer.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ec57627cbbafd..deefed324987c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -238,7 +238,7 @@ class Analyzer( * We need to get all of its subsets for a given GROUPBY expression, the subsets are * represented as sequence of expressions. */ - def cubeExprs(exprs: Seq[Expression]): Seq[Seq[Expression]] = exprs match { + def cubeExprs(exprs: Seq[Expression]): Seq[Seq[Expression]] = exprs.toList match { case x :: xs => val initial = cubeExprs(xs) initial.map(x +: _) ++ initial @@ -409,9 +409,12 @@ class Analyzer( failAnalysis( s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead") - case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) => + // Ensure group by expressions and aggregate expressions have been resolved. + case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) + if (groupByExprs ++ aggregateExpressions).forall(_.resolved) => constructAggregate(cubeExprs(groupByExprs), groupByExprs, aggregateExpressions, child) - case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) => + case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) + if (groupByExprs ++ aggregateExpressions).forall(_.resolved) => constructAggregate(rollupExprs(groupByExprs), groupByExprs, aggregateExpressions, child) // Ensure all the expressions have been resolved. case x: GroupingSets if x.expressions.forall(_.resolved) => From 4df9a426feb63fc6aac36701cc91d204afa7a016 Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Sun, 23 Oct 2016 11:52:28 +0800 Subject: [PATCH 8/9] add comment; simplify some case conditions. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 14 +++++++------- .../plans/logical/basicLogicalOperators.scala | 16 +++++++++++----- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index deefed324987c..491f09d60fb0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -306,18 +306,18 @@ class Analyzer( // GROUPING SETS ((a,b), a), we do not need to change the nullability of a, but we // should change the nullabilty of b to be TRUE. // TODO: For Cube/Rollup just set nullability to be `true`. 
- val expandedAttributes = groupByAliases.zipWithIndex.map { case (a, idx) => - if (selectedGroupByExprs.exists(!_.contains(a.child))) { - a.toAttribute.withNullability(true) + val expandedAttributes = groupByAliases.map { alias => + if (selectedGroupByExprs.exists(!_.contains(alias.child))) { + alias.toAttribute.withNullability(true) } else { - a.toAttribute + alias.toAttribute } } val groupingSetsAttributes = selectedGroupByExprs.map { groupingSetExprs => groupingSetExprs.map { expr => val alias = groupByAliases.find(_.child.semanticEquals(expr)).getOrElse( - failAnalysis(s"$expr doesn't show up in the GROUP BY list")) + failAnalysis(s"$expr doesn't show up in the GROUP BY list $groupByAliases")) // Map alias to expanded attribute. expandedAttributes.find(_.semanticEquals(alias.toAttribute)).getOrElse( alias.toAttribute) @@ -335,7 +335,7 @@ class Analyzer( aggregations: Seq[NamedExpression], groupByAliases: Seq[Alias], groupingAttrs: Seq[Expression], - gid: Attribute): Seq[NamedExpression] = aggregations.map { case expr => + gid: Attribute): Seq[NamedExpression] = aggregations.map { // collect all the found AggregateExpression, so we can check an expression is part of // any AggregateExpression or not. val aggsBuffer = ArrayBuffer[Expression]() @@ -343,7 +343,7 @@ class Analyzer( def isPartOfAggregation(e: Expression): Boolean = { aggsBuffer.exists(a => a.find(_ eq e).isDefined) } - replaceGroupingFunc(expr, groupByExprs, gid).transformDown { + replaceGroupingFunc(_, groupByExprs, gid).transformDown { // AggregateExpression should be computed on the unmodified value of its argument // expressions, so we should not replace any references to grouping expression // inside it. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 28dddd2062e08..dcae7b026f58c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -522,24 +522,30 @@ case class Window( object Expand { /** - * Build bit mask from attributes of selected grouping set. + * Build bit mask from attributes of selected grouping set. A bit in the bitmask is corresponding + * to an attribute in group by attributes sequence, the selected attribute has corresponding bit + * set to 0 and otherwise set to 1. For example, if we have GroupBy attributes (a, b, c, d), the + * bitmask 5(whose binary form is 0101) represents grouping set (a, c). * * @param groupingSetAttrs The attributes of selected grouping set * @param attrMap Mapping group by attributes to its index in attributes sequence * @return The bitmask which represents the selected attributes out of group by attributes. - * A bit in the bitmask is corresponding to an attribute in group by attributes sequence, - * the selected attribute has corresponding bit set to 0 and otherwise set to 1. */ private def buildBitmask( groupingSetAttrs: Seq[Attribute], attrMap: Map[Attribute, Int]): Int = { val numAttributes = attrMap.size val mask = (1 << numAttributes) - 1 - (mask +: groupingSetAttrs.map(attrMap).map(index => + // Calculate the attrbute masks of selected grouping set. 
For example, if we have GroupBy
+    // attributes (a, b, c, d), grouping set (a, c) will produce the following sequence:
+    // (15, 7, 13), whose binary form is (1111, 0111, 1101)
+    val masks = (mask +: groupingSetAttrs.map(attrMap).map(index =>
       // 0 means that the column at the given index is a grouping column, 1 means it is not,
       // so we unset the bit in bitmap.
       ~(1 << (numAttributes - 1 - index))
-    )).reduce(_ & _)
+    ))
+    // Reduce masks to generate a bitmask for the selected grouping set.
+    masks.reduce(_ & _)
   }
 
   /**

From ef3a733590f27e46993b93831585b083c2f83044 Mon Sep 17 00:00:00 2001
From: jiangxingbo
Date: Mon, 7 Nov 2016 00:20:30 +0800
Subject: [PATCH 9/9] add test cases for cubeExprs and rollupExprs

---
 .../sql/catalyst/analysis/Analyzer.scala      | 10 +------
 .../ResolveGroupingAnalyticsSuite.scala       | 26 +++++++++++++++++++
 2 files changed, 27 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 491f09d60fb0e..dd68d60d3e839 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -219,15 +219,7 @@ class Analyzer(
    * We need to get all of its subsets for the rule described above, the subset is
    * represented as sequence of expressions.
    */
-  def rollupExprs(exprs: Seq[Expression]): Seq[Seq[Expression]] = {
-    val buffer = ArrayBuffer.empty[Seq[Expression]]
-    var current = exprs
-    while (current.nonEmpty) {
-      buffer += current
-      current = current.init
-    }
-    buffer += Seq.empty
-  }
+  def rollupExprs(exprs: Seq[Expression]): Seq[Seq[Expression]] = exprs.inits.toSeq
 
   /*
    * GROUP BY a, b, c WITH CUBE
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala
index 0b36bbdc83ab8..2a0205bdc90fe 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala
@@ -37,6 +37,32 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
   lazy val nulStr = Literal(null, StringType)
   lazy val r1 = LocalRelation(a, b, c)
 
+  test("rollupExprs") {
+    val testRollup = (exprs: Seq[Expression], rollup: Seq[Seq[Expression]]) => {
+      val result = SimpleAnalyzer.ResolveGroupingAnalytics.rollupExprs(exprs)
+      assert(result.sortBy(_.hashCode) == rollup.sortBy(_.hashCode))
+    }
+
+    testRollup(Seq(a, b, c), Seq(Seq(), Seq(a), Seq(a, b), Seq(a, b, c)))
+    testRollup(Seq(c, b, a), Seq(Seq(), Seq(c), Seq(c, b), Seq(c, b, a)))
+    testRollup(Seq(a), Seq(Seq(), Seq(a)))
+    testRollup(Seq(), Seq(Seq()))
+  }
+
+  test("cubeExprs") {
+    val testCube = (exprs: Seq[Expression], cube: Seq[Seq[Expression]]) => {
+      val result = SimpleAnalyzer.ResolveGroupingAnalytics.cubeExprs(exprs)
+      assert(result.sortBy(_.hashCode) == cube.sortBy(_.hashCode))
+    }
+
+    testCube(Seq(a, b, c),
+      Seq(Seq(), Seq(a), Seq(b), Seq(c), Seq(a, b), Seq(a, c), Seq(b, c), Seq(a, b, c)))
+    testCube(Seq(c, b, a),
+      Seq(Seq(), Seq(a), Seq(b), Seq(c), Seq(c, b), Seq(c, a), Seq(b, a), Seq(c, b, a)))
+    testCube(Seq(a), Seq(Seq(), Seq(a)))
+    testCube(Seq(), Seq(Seq()))
+  }
+
   test("grouping sets") {
     val originalPlan = GroupingSets(Seq(Seq(), Seq(unresolved_a),
Seq(unresolved_a, unresolved_b)), Seq(unresolved_a, unresolved_b), r1,