Commit dd5cb0a

wangyang authored and hvanhovell committed
[SPARK-17849][SQL] Fix NPE problem when using grouping sets
## What changes were proposed in this pull request?

Prior to this PR, the following code would cause an NPE:

`case class point(a:String, b:String, c:String, d: Int)`
`val data = Seq( point("1","2","3", 1), point("4","5","6", 1), point("7","8","9", 1) )`
`sc.parallelize(data).toDF().registerTempTable("table")`
`spark.sql("select a, b, c, count(d) from table group by a, b, c GROUPING SETS ((a)) ").show()`

The reason is that when the behavior of grouping_id() was changed in #10677, some code that should have been updated along with it was left out. Taking the code above as an example: before #10677, the bit mask for the set "(a)" was `001`; after #10677, it became `011`. However, `nonNullBitmask` was not changed accordingly. This PR fixes the problem.

## How was this patch tested?

Added integration tests.

Author: wangyang <[email protected]>

Closes #15416 from yangw1234/groupingid.

(cherry picked from commit fb0d608)
Signed-off-by: Herman van Hovell <[email protected]>
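The two bit masks quoted in the commit message can be reproduced with a small standalone sketch. It is not Spark code: `maskAfter` follows the convention described in the Analyzer comment of this commit (rightmost bit is the last grouping expression, 0 means the expression is in the set), while `maskBefore` is an assumption inferred from the removed `nonNullBitmask & 1 << idx` check (bit `idx` set means the expression is in the set). All names in the block are hypothetical.

```scala
object GroupingSetMaskSketch {
  // Convention after #10677, as described in the Analyzer comment:
  // the rightmost bit belongs to the last group-by expression, and a 0 bit
  // means the expression is part of the grouping set.
  def maskAfter(groupBy: Seq[String], set: Set[String]): Int =
    groupBy.foldLeft(0) { (acc, name) =>
      (acc << 1) | (if (set.contains(name)) 0 else 1)
    }

  // Assumed convention before #10677 (inferred from the removed code):
  // bit idx is set when the expression at position idx is in the set.
  def maskBefore(groupBy: Seq[String], set: Set[String]): Int =
    groupBy.zipWithIndex.foldLeft(0) { case (acc, (name, idx)) =>
      if (set.contains(name)) acc | (1 << idx) else acc
    }

  // Render a mask as a fixed-width binary string, most significant bit first.
  def toBits(mask: Int, width: Int): String =
    (width - 1 to 0 by -1).map(i => (mask >> i) & 1).mkString

  def main(args: Array[String]): Unit = {
    val groupBy = Seq("a", "b", "c")
    println(toBits(maskBefore(groupBy, Set("a")), 3)) // 001
    println(toBits(maskAfter(groupBy, Set("a")), 3))  // 011
  }
}
```

Under the newer convention, the set bits mark exactly the columns that come out as NULL for that grouping set, which is why the nullability computation had to change with it.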
1 parent 5b9eb42 commit dd5cb0a

3 files changed: +66 -2 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 7 additions & 2 deletions
@@ -305,10 +305,15 @@ class Analyzer(
           case other => Alias(other, other.toString)()
         }
 
-        val nonNullBitmask = x.bitmasks.reduce(_ & _)
+        // The rightmost bit in the bitmasks corresponds to the last expression in groupByAliases
+        // with 0 indicating this expression is in the grouping set. The following line of code
+        // calculates the bitmask representing the expressions that absent in at least one grouping
+        // set (indicated by 1).
+        val nullBitmask = x.bitmasks.reduce(_ | _)
 
+        val attrLength = groupByAliases.length
         val expandedAttributes = groupByAliases.zipWithIndex.map { case (a, idx) =>
-          a.toAttribute.withNullability((nonNullBitmask & 1 << idx) == 0)
+          a.toAttribute.withNullability(((nullBitmask >> (attrLength - idx - 1)) & 1) == 1)
         }
 
         val expand = Expand(x.bitmasks, groupByAliases, expandedAttributes, gid, x.child)
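To see why the removed lines marked the wrong columns as non-nullable, the two nullability computations from this hunk can be compared outside of Spark. The sketch below is a simplified model, assuming plain `Int` bitmasks and a column count in place of `x.bitmasks` and `groupByAliases`; the object and method names are hypothetical.

```scala
object NullabilitySketch {
  // Logic of the removed lines: AND the masks together and test bit idx,
  // treating the leftmost group-by expression as the lowest bit.
  def nullableBefore(bitmasks: Seq[Int], attrLength: Int): Seq[Boolean] = {
    val nonNullBitmask = bitmasks.reduce(_ & _)
    (0 until attrLength).map(idx => (nonNullBitmask & (1 << idx)) == 0)
  }

  // Logic of the added lines: OR the masks together, so a 1 bit means the
  // expression is absent from at least one grouping set, and read the bits
  // with the last expression in the rightmost position.
  def nullableAfter(bitmasks: Seq[Int], attrLength: Int): Seq[Boolean] = {
    val nullBitmask = bitmasks.reduce(_ | _)
    (0 until attrLength).map(idx => ((nullBitmask >> (attrLength - idx - 1)) & 1) == 1)
  }

  def main(args: Array[String]): Unit = {
    // GROUP BY a, b, c GROUPING SETS ((a)) has the single mask 011.
    val masks = Seq(Integer.parseInt("011", 2))
    println(nullableBefore(masks, 3)) // Vector(false, false, true)
    println(nullableAfter(masks, 3))  // Vector(false, true, true)
  }
}
```

With the old logic, column `b` is reported as non-nullable even though it is NULL in every output row for `GROUPING SETS ((a))`, which is consistent with the NPE described in the commit message.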
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+CREATE TEMPORARY VIEW grouping AS SELECT * FROM VALUES
+  ("1", "2", "3", 1),
+  ("4", "5", "6", 1),
+  ("7", "8", "9", 1)
+  as grouping(a, b, c, d);
+
+-- SPARK-17849: grouping set throws NPE #1
+SELECT a, b, c, count(d) FROM grouping GROUP BY a, b, c GROUPING SETS (());
+
+-- SPARK-17849: grouping set throws NPE #2
+SELECT a, b, c, count(d) FROM grouping GROUP BY a, b, c GROUPING SETS ((a));
+
+-- SPARK-17849: grouping set throws NPE #3
+SELECT a, b, c, count(d) FROM grouping GROUP BY a, b, c GROUPING SETS ((c));
+
+
+
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
-- Automatically generated by SQLQueryTestSuite
2+
-- Number of queries: 4
3+
4+
5+
-- !query 0
6+
CREATE TEMPORARY VIEW grouping AS SELECT * FROM VALUES
7+
("1", "2", "3", 1),
8+
("4", "5", "6", 1),
9+
("7", "8", "9", 1)
10+
as grouping(a, b, c, d)
11+
-- !query 0 schema
12+
struct<>
13+
-- !query 0 output
14+
15+
16+
17+
-- !query 1
18+
SELECT a, b, c, count(d) FROM grouping GROUP BY a, b, c GROUPING SETS (())
19+
-- !query 1 schema
20+
struct<a:string,b:string,c:string,count(d):bigint>
21+
-- !query 1 output
22+
NULL NULL NULL 3
23+
24+
25+
-- !query 2
26+
SELECT a, b, c, count(d) FROM grouping GROUP BY a, b, c GROUPING SETS ((a))
27+
-- !query 2 schema
28+
struct<a:string,b:string,c:string,count(d):bigint>
29+
-- !query 2 output
30+
1 NULL NULL 1
31+
4 NULL NULL 1
32+
7 NULL NULL 1
33+
34+
35+
-- !query 3
36+
SELECT a, b, c, count(d) FROM grouping GROUP BY a, b, c GROUPING SETS ((c))
37+
-- !query 3 schema
38+
struct<a:string,b:string,c:string,count(d):bigint>
39+
-- !query 3 output
40+
NULL NULL 3 1
41+
NULL NULL 6 1
42+
NULL NULL 9 1
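The expected rows in the test output can be cross-checked with a plain-Scala model of a single grouping set. This is only an illustrative sketch and does not use Spark: columns outside the set are replaced by `None` (standing in for SQL NULL) before grouping, and `count(d)` is modeled as the group size, which holds here only because `d` is never null in the test data. `GroupingSetsSketch`, `Point`, and `groupingSet` are hypothetical names.

```scala
object GroupingSetsSketch {
  final case class Point(a: String, b: String, c: String, d: Int)

  // Same three rows as the temporary view in the test input.
  val data: Seq[Point] = Seq(
    Point("1", "2", "3", 1),
    Point("4", "5", "6", 1),
    Point("7", "8", "9", 1)
  )

  type Row = (Option[String], Option[String], Option[String], Long)

  // Evaluate one grouping set: keep the columns named in `set`, null out the
  // rest, group on the resulting key, and count the rows in each group.
  def groupingSet(rows: Seq[Point], set: Set[String]): Seq[Row] = {
    def keep(name: String, value: String): Option[String] =
      if (set.contains(name)) Some(value) else None
    rows
      .groupBy(p => (keep("a", p.a), keep("b", p.b), keep("c", p.c)))
      .toSeq
      .map { case ((a, b, c), group) => (a, b, c, group.size.toLong) }
      .sortBy { case (a, b, c, _) => (a, b, c) }
  }

  def main(args: Array[String]): Unit = {
    for (set <- Seq(Set.empty[String], Set("a"), Set("c"))) {
      groupingSet(data, set).foreach { case (a, b, c, n) =>
        println(Seq(a, b, c).map(_.getOrElse("NULL")).mkString(" ") + " " + n)
      }
    }
  }
}
```

Every column that is `None` in some row here is a column that the fixed Analyzer code marks as nullable.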
