Skip to content

Commit 71c73d5

Browse files
committed
[SPARK-30279][SQL] Support 32 or more grouping attributes for GROUPING_ID
### What changes were proposed in this pull request? This pr intends to support 32 or more grouping attributes for GROUPING_ID. In the current master, an integer overflow can occur to compute grouping IDs; https://github.com/apache/spark/blob/e75d9afb2f282ce79c9fd8bce031287739326a4f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala#L613 For example, the query below generates wrong grouping IDs in the master; ``` scala> val numCols = 32 // or, 31 scala> val cols = (0 until numCols).map { i => s"c$i" } scala> sql(s"create table test_$numCols (${cols.map(c => s"$c int").mkString(",")}, v int) using parquet") scala> val insertVals = (0 until numCols).map { _ => 1 }.mkString(",") scala> sql(s"insert into test_$numCols values ($insertVals,3)") scala> sql(s"select grouping_id(), sum(v) from test_$numCols group by grouping sets ((${cols.mkString(",")}), (${cols.init.mkString(",")}))").show(10, false) scala> sql(s"drop table test_$numCols") // numCols = 32 +-------------+------+ |grouping_id()|sum(v)| +-------------+------+ |0 |3 | |0 |3 | // Wrong Grouping ID +-------------+------+ // numCols = 31 +-------------+------+ |grouping_id()|sum(v)| +-------------+------+ |0 |3 | |1 |3 | +-------------+------+ ``` To fix this issue, this pr change code to use long values for `GROUPING_ID` instead of int values. ### Why are the changes needed? To support more cases in `GROUPING_ID`. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Added unit tests. Closes apache#26918 from maropu/FixGroupingIdIssue. Authored-by: Takeshi Yamamuro <[email protected]> Signed-off-by: Takeshi Yamamuro <[email protected]>
1 parent 2e3adad commit 71c73d5

File tree

11 files changed

+117
-53
lines changed

11 files changed

+117
-53
lines changed

docs/sql-migration-guide.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ license: |
2222
* Table of contents
2323
{:toc}
2424

25+
## Upgrading from Spark SQL 3.0 to 3.1
26+
- Since Spark 3.1, grouping_id() returns long values. In Spark version 3.0 and earlier, this function returns int values. To restore the behavior before Spark 3.1, you can set `spark.sql.legacy.integerGroupingId` to `true`.
27+
2528
## Upgrading from Spark SQL 2.4 to 3.0
2629
- Since Spark 3.0, when inserting a value into a table column with a different data type, the type coercion is performed as per ANSI SQL standard. Certain unreasonable type conversions such as converting `string` to `int` and `double` to `boolean` are disallowed. A runtime exception will be thrown if the value is out-of-range for the data type of the column. In Spark version 2.4 and earlier, type conversions during table insertion are allowed as long as they are valid `Cast`. When inserting an out-of-range value to a integral field, the low-order bits of the value is inserted(the same as Java/Scala numeric type casting). For example, if 257 is inserted to a field of byte type, the result is 1. The behavior is controlled by the option `spark.sql.storeAssignmentPolicy`, with a default value as "ANSI". Setting the option as "Legacy" restores the previous behavior.
2730

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ class Analyzer(
437437
val idx = groupByExprs.indexWhere(_.semanticEquals(col))
438438
if (idx >= 0) {
439439
Alias(Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)),
440-
Literal(1)), ByteType), toPrettySQL(e))()
440+
Literal(1L)), ByteType), toPrettySQL(e))()
441441
} else {
442442
throw new AnalysisException(s"Column of grouping ($col) can't be found " +
443443
s"in grouping columns ${groupByExprs.mkString(",")}")
@@ -531,8 +531,6 @@ class Analyzer(
531531
groupByExprs: Seq[Expression],
532532
aggregationExprs: Seq[NamedExpression],
533533
child: LogicalPlan): LogicalPlan = {
534-
val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
535-
536534
// In case of ANSI-SQL compliant syntax for GROUPING SETS, groupByExprs is optional and
537535
// can be null. In such case, we derive the groupByExprs from the user supplied values for
538536
// grouping sets.
@@ -551,12 +549,18 @@ class Analyzer(
551549
groupByExprs
552550
}
553551

552+
if (finalGroupByExpressions.size > GroupingID.dataType.defaultSize * 8) {
553+
throw new AnalysisException(
554+
s"Grouping sets size cannot be greater than ${GroupingID.dataType.defaultSize * 8}")
555+
}
556+
554557
// Expand works by setting grouping expressions to null as determined by the
555558
// `selectedGroupByExprs`. To prevent these null values from being used in an aggregate
556559
// instead of the original value we need to create new aliases for all group by expressions
557560
// that will only be used for the intended purpose.
558561
val groupByAliases = constructGroupByAlias(finalGroupByExpressions)
559562

563+
val gid = AttributeReference(VirtualColumn.groupingIdName, GroupingID.dataType, false)()
560564
val expand = constructExpand(selectedGroupByExprs, child, groupByAliases, gid)
561565
val groupingAttrs = expand.output.drop(child.output.length)
562566

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.sql.catalyst.InternalRow
2121
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
22+
import org.apache.spark.sql.internal.SQLConf
2223
import org.apache.spark.sql.types._
2324

2425
/**
@@ -49,11 +50,11 @@ trait GroupingSet extends Expression with CodegenFallback {
4950
> SELECT name, age, count(*) FROM VALUES (2, 'Alice'), (5, 'Bob') people(age, name) GROUP BY _FUNC_(name, age);
5051
Bob 5 1
5152
Alice 2 1
52-
NULL NULL 2
53-
NULL 5 1
54-
Bob NULL 1
5553
Alice NULL 1
5654
NULL 2 1
55+
NULL NULL 2
56+
Bob NULL 1
57+
NULL 5 1
5758
""",
5859
since = "2.0.0")
5960
// scalastyle:on line.size.limit line.contains.tab
@@ -70,9 +71,9 @@ case class Cube(groupByExprs: Seq[Expression]) extends GroupingSet {}
7071
> SELECT name, age, count(*) FROM VALUES (2, 'Alice'), (5, 'Bob') people(age, name) GROUP BY _FUNC_(name, age);
7172
Bob 5 1
7273
Alice 2 1
74+
Alice NULL 1
7375
NULL NULL 2
7476
Bob NULL 1
75-
Alice NULL 1
7677
""",
7778
since = "2.0.0")
7879
// scalastyle:on line.size.limit line.contains.tab
@@ -91,8 +92,8 @@ case class Rollup(groupByExprs: Seq[Expression]) extends GroupingSet {}
9192
examples = """
9293
Examples:
9394
> SELECT name, _FUNC_(name), sum(age) FROM VALUES (2, 'Alice'), (5, 'Bob') people(age, name) GROUP BY cube(name);
94-
Bob 0 5
9595
Alice 0 2
96+
Bob 0 5
9697
NULL 1 7
9798
""",
9899
since = "2.0.0")
@@ -120,13 +121,13 @@ case class Grouping(child: Expression) extends Expression with Unevaluable {
120121
examples = """
121122
Examples:
122123
> SELECT name, _FUNC_(), sum(age), avg(height) FROM VALUES (2, 'Alice', 165), (5, 'Bob', 180) people(age, name, height) GROUP BY cube(name, height);
123-
NULL 2 5 180.0
124124
Alice 0 2 165.0
125-
NULL 3 7 172.5
126-
NULL 2 2 165.0
127-
Bob 1 5 180.0
128125
Alice 1 2 165.0
126+
NULL 3 7 172.5
129127
Bob 0 5 180.0
128+
Bob 1 5 180.0
129+
NULL 2 2 165.0
130+
NULL 2 5 180.0
130131
""",
131132
note = """
132133
Input columns should match with grouping columns exactly, or empty (means all the grouping
@@ -139,7 +140,14 @@ case class GroupingID(groupByExprs: Seq[Expression]) extends Expression with Une
139140
override lazy val references: AttributeSet =
140141
AttributeSet(VirtualColumn.groupingIdAttribute :: Nil)
141142
override def children: Seq[Expression] = groupByExprs
142-
override def dataType: DataType = IntegerType
143+
override def dataType: DataType = GroupingID.dataType
143144
override def nullable: Boolean = false
144145
override def prettyName: String = "grouping_id"
145146
}
147+
148+
object GroupingID {
149+
150+
def dataType: DataType = {
151+
if (SQLConf.get.integerGroupingIdEnabled) IntegerType else LongType
152+
}
153+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -608,16 +608,17 @@ object Expand {
608608
*/
609609
private def buildBitmask(
610610
groupingSetAttrs: Seq[Attribute],
611-
attrMap: Map[Attribute, Int]): Int = {
611+
attrMap: Map[Attribute, Int]): Long = {
612612
val numAttributes = attrMap.size
613-
val mask = (1 << numAttributes) - 1
613+
assert(numAttributes <= GroupingID.dataType.defaultSize * 8)
614+
val mask = if (numAttributes != 64) (1L << numAttributes) - 1 else 0xFFFFFFFFFFFFFFFFL
614615
// Calculate the attrbute masks of selected grouping set. For example, if we have GroupBy
615616
// attributes (a, b, c, d), grouping set (a, c) will produce the following sequence:
616617
// (15, 7, 13), whose binary form is (1111, 0111, 1101)
617618
val masks = (mask +: groupingSetAttrs.map(attrMap).map(index =>
618619
// 0 means that the column at the given index is a grouping column, 1 means it is not,
619620
// so we unset the bit in bitmap.
620-
~(1 << (numAttributes - 1 - index))
621+
~(1L << (numAttributes - 1 - index))
621622
))
622623
// Reduce masks to generate a bitmask for the selected grouping set.
623624
masks.reduce(_ & _)
@@ -657,7 +658,11 @@ object Expand {
657658
attr
658659
}
659660
// groupingId is the last output, here we use the bit mask as the concrete value for it.
660-
} :+ Literal.create(buildBitmask(groupingSetAttrs, attrMap), IntegerType)
661+
} :+ {
662+
val bitMask = buildBitmask(groupingSetAttrs, attrMap)
663+
val dataType = GroupingID.dataType
664+
Literal.create(if (dataType.sameType(IntegerType)) bitMask.toInt else bitMask, dataType)
665+
}
661666

662667
if (hasDuplicateGroupingSets) {
663668
// If `groupingSetsAttrs` has duplicate entries (e.g., GROUPING SETS ((key), (key))),

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2496,6 +2496,13 @@ object SQLConf {
24962496
.booleanConf
24972497
.createWithDefault(false)
24982498

2499+
val LEGACY_INTEGER_GROUPING_ID =
2500+
buildConf("spark.sql.legacy.integerGroupingId")
2501+
.internal()
2502+
.doc("When true, grouping_id() returns int values instead of long values.")
2503+
.booleanConf
2504+
.createWithDefault(false)
2505+
24992506
/**
25002507
* Holds information about keys that have been deprecated.
25012508
*
@@ -3072,6 +3079,8 @@ class SQLConf extends Serializable with Logging {
30723079

30733080
def csvFilterPushDown: Boolean = getConf(CSV_FILTER_PUSHDOWN_ENABLED)
30743081

3082+
def integerGroupingIdEnabled: Boolean = getConf(SQLConf.LEGACY_INTEGER_GROUPING_ID)
3083+
30753084
/** ********************** SQLConf functionality methods ************ */
30763085

30773086
/** Set Spark SQL configuration properties. */

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
3333
lazy val unresolved_a = UnresolvedAttribute("a")
3434
lazy val unresolved_b = UnresolvedAttribute("b")
3535
lazy val unresolved_c = UnresolvedAttribute("c")
36-
lazy val gid = 'spark_grouping_id.int.withNullability(false)
37-
lazy val hive_gid = 'grouping__id.int.withNullability(false)
38-
lazy val grouping_a = Cast(ShiftRight(gid, 1) & 1, ByteType, Option(TimeZone.getDefault().getID))
36+
lazy val gid = 'spark_grouping_id.long.withNullability(false)
37+
lazy val hive_gid = 'grouping__id.long.withNullability(false)
38+
lazy val grouping_a = Cast(ShiftRight(gid, 1) & 1L, ByteType, Option(TimeZone.getDefault().getID))
3939
lazy val nulInt = Literal(null, IntegerType)
4040
lazy val nulStr = Literal(null, StringType)
4141
lazy val r1 = LocalRelation(a, b, c)
@@ -72,7 +72,7 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
7272
Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))))
7373
val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")),
7474
Expand(
75-
Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)),
75+
Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, a, b, 0L)),
7676
Seq(a, b, c, a, b, gid),
7777
Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
7878
checkAnalysis(originalPlan, expected)
@@ -98,7 +98,7 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
9898
Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))))
9999
val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")),
100100
Expand(
101-
Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)),
101+
Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, a, b, 0L)),
102102
Seq(a, b, c, a, b, gid),
103103
Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
104104
checkAnalysis(originalPlan, expected)
@@ -125,16 +125,16 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
125125
Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1)
126126
val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")),
127127
Expand(
128-
Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1),
129-
Seq(a, b, c, nulInt, b, 2), Seq(a, b, c, nulInt, nulStr, 3)),
128+
Seq(Seq(a, b, c, a, b, 0L), Seq(a, b, c, a, nulStr, 1L),
129+
Seq(a, b, c, nulInt, b, 2L), Seq(a, b, c, nulInt, nulStr, 3L)),
130130
Seq(a, b, c, a, b, gid),
131131
Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
132132
checkAnalysis(originalPlan, expected)
133133

134134
val originalPlan2 = Aggregate(Seq(Cube(Seq())), Seq(UnresolvedAlias(count(unresolved_c))), r1)
135135
val expected2 = Aggregate(Seq(gid), Seq(count(c).as("count(c)")),
136136
Expand(
137-
Seq(Seq(a, b, c, 0)),
137+
Seq(Seq(a, b, c, 0L)),
138138
Seq(a, b, c, gid),
139139
Project(Seq(a, b, c), r1)))
140140
checkAnalysis(originalPlan2, expected2)
@@ -145,15 +145,15 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
145145
Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1)
146146
val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")),
147147
Expand(
148-
Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, nulInt, nulStr, 3)),
148+
Seq(Seq(a, b, c, a, b, 0L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, nulInt, nulStr, 3L)),
149149
Seq(a, b, c, a, b, gid),
150150
Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
151151
checkAnalysis(originalPlan, expected)
152152

153153
val originalPlan2 = Aggregate(Seq(Rollup(Seq())), Seq(UnresolvedAlias(count(unresolved_c))), r1)
154154
val expected2 = Aggregate(Seq(gid), Seq(count(c).as("count(c)")),
155155
Expand(
156-
Seq(Seq(a, b, c, 0)),
156+
Seq(Seq(a, b, c, 0L)),
157157
Seq(a, b, c, gid),
158158
Project(Seq(a, b, c), r1)))
159159
checkAnalysis(originalPlan2, expected2)
@@ -168,7 +168,7 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
168168
val expected = Aggregate(Seq(a, b, gid),
169169
Seq(a, b, count(c).as("count(c)"), grouping_a.as("grouping(a)")),
170170
Expand(
171-
Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)),
171+
Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, a, b, 0L)),
172172
Seq(a, b, c, a, b, gid),
173173
Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
174174
checkAnalysis(originalPlan, expected)
@@ -180,8 +180,8 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
180180
val expected2 = Aggregate(Seq(a, b, gid),
181181
Seq(a, b, count(c).as("count(c)"), grouping_a.as("grouping(a)")),
182182
Expand(
183-
Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1),
184-
Seq(a, b, c, nulInt, b, 2), Seq(a, b, c, nulInt, nulStr, 3)),
183+
Seq(Seq(a, b, c, a, b, 0L), Seq(a, b, c, a, nulStr, 1L),
184+
Seq(a, b, c, nulInt, b, 2L), Seq(a, b, c, nulInt, nulStr, 3L)),
185185
Seq(a, b, c, a, b, gid),
186186
Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
187187
checkAnalysis(originalPlan2, expected2)
@@ -193,7 +193,7 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
193193
val expected3 = Aggregate(Seq(a, b, gid),
194194
Seq(a, b, count(c).as("count(c)"), grouping_a.as("grouping(a)")),
195195
Expand(
196-
Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, nulInt, nulStr, 3)),
196+
Seq(Seq(a, b, c, a, b, 0L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, nulInt, nulStr, 3L)),
197197
Seq(a, b, c, a, b, gid),
198198
Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
199199
checkAnalysis(originalPlan3, expected3)
@@ -208,7 +208,7 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
208208
val expected = Aggregate(Seq(a, b, gid),
209209
Seq(a, b, count(c).as("count(c)"), gid.as("grouping_id(a, b)")),
210210
Expand(
211-
Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)),
211+
Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, a, b, 0L)),
212212
Seq(a, b, c, a, b, gid),
213213
Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
214214
checkAnalysis(originalPlan, expected)
@@ -220,8 +220,8 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
220220
val expected2 = Aggregate(Seq(a, b, gid),
221221
Seq(a, b, count(c).as("count(c)"), gid.as("grouping_id(a, b)")),
222222
Expand(
223-
Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1),
224-
Seq(a, b, c, nulInt, b, 2), Seq(a, b, c, nulInt, nulStr, 3)),
223+
Seq(Seq(a, b, c, a, b, 0L), Seq(a, b, c, a, nulStr, 1L),
224+
Seq(a, b, c, nulInt, b, 2L), Seq(a, b, c, nulInt, nulStr, 3L)),
225225
Seq(a, b, c, a, b, gid),
226226
Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
227227
checkAnalysis(originalPlan2, expected2)
@@ -233,7 +233,7 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
233233
val expected3 = Aggregate(Seq(a, b, gid),
234234
Seq(a, b, count(c).as("count(c)"), gid.as("grouping_id(a, b)")),
235235
Expand(
236-
Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, nulInt, nulStr, 3)),
236+
Seq(Seq(a, b, c, a, b, 0L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, nulInt, nulStr, 3L)),
237237
Seq(a, b, c, a, b, gid),
238238
Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
239239
checkAnalysis(originalPlan3, expected3)
@@ -249,7 +249,8 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
249249
Aggregate(Seq(a, b, gid),
250250
Seq(a, b, gid),
251251
Expand(
252-
Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)),
252+
Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L),
253+
Seq(a, b, c, a, b, 0L)),
253254
Seq(a, b, c, a, b, gid),
254255
Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))))
255256
checkAnalysis(originalPlan, expected)
@@ -260,14 +261,15 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
260261
Seq("grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup"))
261262

262263
// Filter with GroupingID
263-
val originalPlan3 = Filter(GroupingID(Seq(unresolved_a, unresolved_b)) === 1,
264+
val originalPlan3 = Filter(GroupingID(Seq(unresolved_a, unresolved_b)) === 1L,
264265
GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)),
265266
Seq(unresolved_a, unresolved_b), r1, Seq(unresolved_a, unresolved_b)))
266-
val expected3 = Project(Seq(a, b), Filter(gid === 1,
267+
val expected3 = Project(Seq(a, b), Filter(gid === 1L,
267268
Aggregate(Seq(a, b, gid),
268269
Seq(a, b, gid),
269270
Expand(
270-
Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)),
271+
Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L),
272+
Seq(a, b, c, a, b, 0L)),
271273
Seq(a, b, c, a, b, gid),
272274
Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))))
273275
checkAnalysis(originalPlan3, expected3)
@@ -289,7 +291,8 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
289291
Aggregate(Seq(a, b, gid),
290292
Seq(a, b, grouping_a.as("aggOrder")),
291293
Expand(
292-
Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)),
294+
Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L),
295+
Seq(a, b, c, a, b, 0L)),
293296
Seq(a, b, c, a, b, gid),
294297
Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))))
295298
checkAnalysis(originalPlan, expected)
@@ -305,11 +308,12 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
305308
GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)),
306309
Seq(unresolved_a, unresolved_b), r1, Seq(unresolved_a, unresolved_b)))
307310
val expected3 = Project(Seq(a, b), Sort(
308-
Seq(SortOrder('aggOrder.int.withNullability(false), Ascending)), true,
311+
Seq(SortOrder('aggOrder.long.withNullability(false), Ascending)), true,
309312
Aggregate(Seq(a, b, gid),
310313
Seq(a, b, gid.as("aggOrder")),
311314
Expand(
312-
Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)),
315+
Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L),
316+
Seq(a, b, c, a, b, 0L)),
313317
Seq(a, b, c, a, b, gid),
314318
Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))))
315319
checkAnalysis(originalPlan3, expected3)

0 commit comments

Comments
 (0)