From b5ada3feb7d243859714c04ec4fb8c225c1781e0 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 18 Jul 2018 22:12:33 -0700 Subject: [PATCH 1/5] [SPARK-24424] Support ANSI-SQL compliant syntax for GROUPING SETS --- .../spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../sql/catalyst/analysis/Analyzer.scala | 21 +++- .../ResolveGroupingAnalyticsSuite.scala | 34 +++++ .../sql-tests/inputs/grouping_set.sql | 34 +++++ .../sql-tests/results/grouping_set.sql.out | 117 +++++++++++++++++- 5 files changed, 204 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 1b43874af6feb..3c9c34840f03f 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -400,12 +400,12 @@ hintStatement fromClause : FROM relation (',' relation)* lateralView* pivotClause? ; - aggregation : GROUP BY groupingExpressions+=expression (',' groupingExpressions+=expression)* ( WITH kind=ROLLUP | WITH kind=CUBE | kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')')? 
+ | GROUP BY kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')' ; groupingSet diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 59c371eb1557b..3868e7a797447 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -440,19 +440,36 @@ class Analyzer( groupByExprs: Seq[Expression], aggregationExprs: Seq[NamedExpression], child: LogicalPlan): LogicalPlan = { + + val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)() + val finalGroupByExpressions = if (groupByExprs == Nil) { + selectedGroupByExprs.flatten.foldLeft(Seq.empty[Expression]) { (result, currentExpr) => + // Only unique expressions are included in the group by expressions and is determined + // based on their semantic equality. Example. grouping sets ((a * b), (b * a)) results + // in grouping expression (a * b) + if (result.find(_.semanticEquals(currentExpr)).isDefined) { + result + } else { + result :+ currentExpr + } + } + } else { + groupByExprs + } + // Expand works by setting grouping expressions to null as determined by the // `selectedGroupByExprs`. To prevent these null values from being used in an aggregate // instead of the original value we need to create new aliases for all group by expressions // that will only be used for the intended purpose. 
- val groupByAliases = constructGroupByAlias(groupByExprs) + val groupByAliases = constructGroupByAlias(finalGroupByExpressions) val expand = constructExpand(selectedGroupByExprs, child, groupByAliases, gid) val groupingAttrs = expand.output.drop(child.output.length) val aggregations = constructAggregateExprs( - groupByExprs, aggregationExprs, groupByAliases, groupingAttrs, gid) + finalGroupByExpressions, aggregationExprs, groupByAliases, groupingAttrs, gid) Aggregate(groupingAttrs, aggregations, expand) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala index 553b1598e7750..2a93b5b110029 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala @@ -91,6 +91,40 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { assertAnalysisError(originalPlan3, Seq("doesn't show up in the GROUP BY list")) } + test("grouping sets with no explicit group by expressions") { + val originalPlan = GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), + Nil, r1, + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)))) + val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")), + Expand( + Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)), + Seq(a, b, c, a, b, gid), + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + checkAnalysis(originalPlan, expected) + + val originalPlan2 = GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), + Nil, r1, + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)))) + checkAnalysis(originalPlan2, expected) + + + // Computation of grouping expression should remove duplicate 
expression based on their + // semantics (semanticEqual). + val originalPlan3 = GroupingSets(Seq(Seq(Multiply(unresolved_a, Literal(2))), + Seq(Multiply(Literal(2), unresolved_a), unresolved_b)), Nil, r1, + Seq(UnresolvedAlias(Multiply(unresolved_a, Literal(2))), + unresolved_b, UnresolvedAlias(count(unresolved_c)))) + + val resultPlan = getAnalyzer(true).executeAndCheck(originalPlan3) + val gExpressions = resultPlan.asInstanceOf[Aggregate].groupingExpressions + assert(gExpressions.size == 3) + val firstGroupingExprAttrName = + gExpressions(0).asInstanceOf[AttributeReference].name.replaceAll("#[0-9]*", "#0") + assert(firstGroupingExprAttrName == "(a#0 * 2)") + assert(gExpressions(1).asInstanceOf[AttributeReference].name == "b") + assert(gExpressions(2).asInstanceOf[AttributeReference].name == VirtualColumn.groupingIdName) + } + test("cube") { val originalPlan = Aggregate(Seq(Cube(Seq(unresolved_a, unresolved_b))), Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1) diff --git a/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql b/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql index 3594283505280..cc8e5c1773d32 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql @@ -13,5 +13,39 @@ SELECT a, b, c, count(d) FROM grouping GROUP BY a, b, c GROUPING SETS ((a)); -- SPARK-17849: grouping set throws NPE #3 SELECT a, b, c, count(d) FROM grouping GROUP BY a, b, c GROUPING SETS ((c)); +-- Group sets without explicit group by +SELECT c1, sum(c2) FROM (VALUES ('x', 10, 0), ('y', 20, 0)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1); +-- Group sets without group by and with grouping +SELECT c1, sum(c2), grouping(c1) FROM (VALUES ('x', 10, 0), ('y', 20, 0)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1); + +-- Mutiple grouping within a grouping set +SELECT c1, c2, Sum(c3), grouping__id +FROM (VALUES ('x', 'a', 10), ('y', 'b', 20) ) AS t (c1, 
c2, c3) +GROUP BY GROUPING SETS ( ( c1 ), ( c2 ) ) +HAVING GROUPING__ID > 1; + +-- Group sets without explicit group by +SELECT grouping(c1) FROM (VALUES ('x', 'a', 10), ('y', 'b', 20)) AS t (c1, c2, c3) GROUP BY c1,c2 GROUPING SETS (c1,c2); + +-- Mutiple grouping within a grouping set +SELECT -c1 AS c1 FROM (values (1,2), (3,2)) t(c1, c2) GROUP BY GROUPING SETS ((c1), (c1, c2)); + +-- complex expression in grouping sets +SELECT a + b, b, sum(c) FROM (VALUES (1,1,1),(2,2,2)) AS t(a,b,c) GROUP BY GROUPING SETS ( (a + b), (b)); + +-- complex expression in grouping sets +SELECT a + b, b, sum(c) FROM (VALUES (1,1,1),(2,2,2)) AS t(a,b,c) GROUP BY GROUPING SETS ( (a + b), (b + a), (b)); + +-- more query constructs with grouping sets +SELECT c1 AS col1, c2 AS col2 +FROM (VALUES (1, 2), (3, 2)) t(c1, c2) +GROUP BY GROUPING SETS ( ( c1 ), ( c1, c2 ) ) +HAVING col2 IS NOT NULL +ORDER BY -col1; + +-- negative tests - must have at least one grouping expression +SELECT a, b, c, count(d) FROM grouping GROUP BY WITH ROLLUP; + +SELECT a, b, c, count(d) FROM grouping GROUP BY WITH CUBE; diff --git a/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out b/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out index edb38a52b7514..e665201635390 100644 --- a/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 4 +-- Number of queries: 14 -- !query 0 @@ -40,3 +40,118 @@ struct NULL NULL 3 1 NULL NULL 6 1 NULL NULL 9 1 + + +-- !query 4 +SELECT c1, sum(c2) FROM (VALUES ('x', 10, 0), ('y', 20, 0)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1) +-- !query 4 schema +struct +-- !query 4 output +x 10 +y 20 + + +-- !query 5 +SELECT c1, sum(c2), grouping(c1) FROM (VALUES ('x', 10, 0), ('y', 20, 0)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1) +-- !query 5 schema +struct +-- !query 5 output +x 10 0 
+y 20 0 + + +-- !query 6 +SELECT c1, c2, Sum(c3), grouping__id +FROM (VALUES ('x', 'a', 10), ('y', 'b', 20) ) AS t (c1, c2, c3) +GROUP BY GROUPING SETS ( ( c1 ), ( c2 ) ) +HAVING GROUPING__ID > 1 +-- !query 6 schema +struct +-- !query 6 output +NULL a 10 2 +NULL b 20 2 + + +-- !query 7 +SELECT grouping(c1) FROM (VALUES ('x', 'a', 10), ('y', 'b', 20)) AS t (c1, c2, c3) GROUP BY c1,c2 GROUPING SETS (c1,c2) +-- !query 7 schema +struct +-- !query 7 output +0 +0 +1 +1 + + +-- !query 8 +SELECT -c1 AS c1 FROM (values (1,2), (3,2)) t(c1, c2) GROUP BY GROUPING SETS ((c1), (c1, c2)) +-- !query 8 schema +struct +-- !query 8 output +-1 +-1 +-3 +-3 + + +-- !query 9 +SELECT a + b, b, sum(c) FROM (VALUES (1,1,1),(2,2,2)) AS t(a,b,c) GROUP BY GROUPING SETS ( (a + b), (b)) +-- !query 9 schema +struct<(a + b):int,b:int,sum(c):bigint> +-- !query 9 output +2 NULL 1 +4 NULL 2 +NULL 1 1 +NULL 2 2 + + +-- !query 10 +SELECT a + b, b, sum(c) FROM (VALUES (1,1,1),(2,2,2)) AS t(a,b,c) GROUP BY GROUPING SETS ( (a + b), (b + a), (b)) +-- !query 10 schema +struct<(a + b):int,b:int,sum(c):bigint> +-- !query 10 output +2 NULL 2 +4 NULL 4 +NULL 1 1 +NULL 2 2 + + +-- !query 11 +SELECT c1 AS col1, c2 AS col2 +FROM (VALUES (1, 2), (3, 2)) t(c1, c2) +GROUP BY GROUPING SETS ( ( c1 ), ( c1, c2 ) ) +HAVING col2 IS NOT NULL +ORDER BY -col1 +-- !query 11 schema +struct +-- !query 11 output +3 2 +1 2 + + +-- !query 12 +SELECT a, b, c, count(d) FROM grouping GROUP BY WITH ROLLUP +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.catalyst.parser.ParseException + +extraneous input 'ROLLUP' expecting (line 1, pos 53) + +== SQL == +SELECT a, b, c, count(d) FROM grouping GROUP BY WITH ROLLUP +-----------------------------------------------------^^^ + + +-- !query 13 +SELECT a, b, c, count(d) FROM grouping GROUP BY WITH CUBE +-- !query 13 schema +struct<> +-- !query 13 output +org.apache.spark.sql.catalyst.parser.ParseException + +extraneous input 'CUBE' expecting (line 1, pos 53) + +== SQL 
== +SELECT a, b, c, count(d) FROM grouping GROUP BY WITH CUBE +-----------------------------------------------------^^^ From ac8f04fe26ad48bbd51754bf257da7e52866d87a Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 19 Jul 2018 11:04:20 -0700 Subject: [PATCH 2/5] code review --- .../main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 1 + .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 3c9c34840f03f..2aca10f1bfbc7 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -400,6 +400,7 @@ hintStatement fromClause : FROM relation (',' relation)* lateralView* pivotClause? ; + aggregation : GROUP BY groupingExpressions+=expression (',' groupingExpressions+=expression)* ( WITH kind=ROLLUP diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3868e7a797447..1a03f6ad06017 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -440,8 +440,6 @@ class Analyzer( groupByExprs: Seq[Expression], aggregationExprs: Seq[NamedExpression], child: LogicalPlan): LogicalPlan = { - - val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)() val finalGroupByExpressions = if (groupByExprs == Nil) { From 7cf187db02a54bcfd3b44e0710d95462b273ea97 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 19 Jul 2018 13:49:35 -0700 Subject: [PATCH 3/5] Remove redundant test --- .../analysis/ResolveGroupingAnalyticsSuite.scala | 10 ++-------- 1 
file changed, 2 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala index 2a93b5b110029..8da4d7e3aa372 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala @@ -102,20 +102,14 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) checkAnalysis(originalPlan, expected) - val originalPlan2 = GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), - Nil, r1, - Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)))) - checkAnalysis(originalPlan2, expected) - - // Computation of grouping expression should remove duplicate expression based on their // semantics (semanticEqual). 
- val originalPlan3 = GroupingSets(Seq(Seq(Multiply(unresolved_a, Literal(2))), + val originalPlan2 = GroupingSets(Seq(Seq(Multiply(unresolved_a, Literal(2))), Seq(Multiply(Literal(2), unresolved_a), unresolved_b)), Nil, r1, Seq(UnresolvedAlias(Multiply(unresolved_a, Literal(2))), unresolved_b, UnresolvedAlias(count(unresolved_c)))) - val resultPlan = getAnalyzer(true).executeAndCheck(originalPlan3) + val resultPlan = getAnalyzer(true).executeAndCheck(originalPlan2) val gExpressions = resultPlan.asInstanceOf[Aggregate].groupingExpressions assert(gExpressions.size == 3) val firstGroupingExprAttrName = From e0c57f73ec4c3e24e4af107cb457c9b1c13f4174 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 19 Jul 2018 16:41:19 -0700 Subject: [PATCH 4/5] Add comment --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1a03f6ad06017..b682a30f437b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -442,6 +442,9 @@ class Analyzer( child: LogicalPlan): LogicalPlan = { val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)() + // In case of ANSI-SQL compliant syntax for GROUPING SETS, groupByExprs is optional and + // can be empty. In such case, we derive the groupByExprs from the user supplied values for + // grouping sets. 
val finalGroupByExpressions = if (groupByExprs == Nil) { selectedGroupByExprs.flatten.foldLeft(Seq.empty[Expression]) { (result, currentExpr) => // Only unique expressions are included in the group by expressions and is determined From 2ecf3e183e51f13ec9c92692d29854f01eb327c5 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 19 Jul 2018 19:35:19 -0700 Subject: [PATCH 5/5] Add one more negative test --- .../resources/sql-tests/inputs/grouping_set.sql | 4 +++- .../sql-tests/results/grouping_set.sql.out | 13 +++++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql b/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql index cc8e5c1773d32..6bbde9f38d657 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql @@ -26,7 +26,7 @@ GROUP BY GROUPING SETS ( ( c1 ), ( c2 ) ) HAVING GROUPING__ID > 1; -- Group sets without explicit group by -SELECT grouping(c1) FROM (VALUES ('x', 'a', 10), ('y', 'b', 20)) AS t (c1, c2, c3) GROUP BY c1,c2 GROUPING SETS (c1,c2); +SELECT grouping(c1) FROM (VALUES ('x', 'a', 10), ('y', 'b', 20)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1,c2); -- Mutiple grouping within a grouping set SELECT -c1 AS c1 FROM (values (1,2), (3,2)) t(c1, c2) GROUP BY GROUPING SETS ((c1), (c1, c2)); @@ -49,3 +49,5 @@ SELECT a, b, c, count(d) FROM grouping GROUP BY WITH ROLLUP; SELECT a, b, c, count(d) FROM grouping GROUP BY WITH CUBE; +SELECT c1 FROM (values (1,2), (3,2)) t(c1, c2) GROUP BY GROUPING SETS (()); + diff --git a/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out b/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out index e665201635390..34ab09c5e3bba 100644 --- a/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by 
SQLQueryTestSuite --- Number of queries: 14 +-- Number of queries: 15 -- !query 0 @@ -73,7 +73,7 @@ NULL b 20 2 -- !query 7 -SELECT grouping(c1) FROM (VALUES ('x', 'a', 10), ('y', 'b', 20)) AS t (c1, c2, c3) GROUP BY c1,c2 GROUPING SETS (c1,c2) +SELECT grouping(c1) FROM (VALUES ('x', 'a', 10), ('y', 'b', 20)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1,c2) -- !query 7 schema struct -- !query 7 output @@ -155,3 +155,12 @@ extraneous input 'CUBE' expecting (line 1, pos 53) == SQL == SELECT a, b, c, count(d) FROM grouping GROUP BY WITH CUBE -----------------------------------------------------^^^ + + +-- !query 14 +SELECT c1 FROM (values (1,2), (3,2)) t(c1, c2) GROUP BY GROUPING SETS (()) +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +expression '`c1`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.;