Skip to content

Commit 70b169c

Browse files
committed
Remove groupOrdering.
1 parent 4721936 commit 70b169c

File tree

2 files changed

+27
-26
lines changed

2 files changed

+27
-26
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate2Sort.scala

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -149,14 +149,6 @@ case class Aggregate2Sort(
149149
// This is used to project expressions for the grouping expressions.
150150
protected val groupGenerator =
151151
newMutableProjection(groupingExpressions, child.output)()
152-
// A ordering used to compare if a new row belongs to the current group
153-
// or a new group.
154-
private val groupOrdering: Ordering[InternalRow] = {
155-
val groupingAttributes = groupingExpressions.map(_.toAttribute)
156-
newOrdering(
157-
groupingAttributes.map(expr => SortOrder(expr, Ascending)),
158-
groupingAttributes)
159-
}
160152
// The partition key of the current partition.
161153
private var currentGroupingKey: InternalRow = _
162154
// The partition key of next partition.
@@ -182,18 +174,18 @@ case class Aggregate2Sort(
182174
// aggregate function, the size of the buffer matches the number of values in the
183175
// input rows. To simplify the code for code-gen, we need create some dummy
184176
// attributes and expressions for these grouping expressions.
185-
val offsetAttributes = {
177+
private val offsetAttributes = {
186178
if (partialAggregation) {
187179
Nil
188180
} else {
189181
Seq.fill(groupingExpressions.length)(AttributeReference("offset", NullType)())
190182
}
191183
}
192-
val offsetExpressions =
184+
private val offsetExpressions =
193185
if (partialAggregation) Nil else Seq.fill(groupingExpressions.length)(NoOp)
194186

195187
// This projection is used to initialize buffer values for all AlgebraicAggregates.
196-
val algebraicInitialProjection = {
188+
private val algebraicInitialProjection = {
197189
val initExpressions = offsetExpressions ++ aggregateFunctions.flatMap {
198190
case ae: AlgebraicAggregate => ae.initialValues
199191
case agg: AggregateFunction2 => NoOp :: Nil
@@ -202,7 +194,7 @@ case class Aggregate2Sort(
202194
}
203195

204196
// This projection is used to update buffer values for all AlgebraicAggregates.
205-
lazy val algebraicUpdateProjection = {
197+
private lazy val algebraicUpdateProjection = {
206198
val bufferSchema = aggregateFunctions.flatMap {
207199
case ae: AlgebraicAggregate => ae.bufferAttributes
208200
case agg: AggregateFunction2 => agg.bufferAttributes
@@ -215,7 +207,7 @@ case class Aggregate2Sort(
215207
}
216208

217209
// This projection is used to merge buffer values for all AlgebraicAggregates.
218-
lazy val algebraicMergeProjection = {
210+
private lazy val algebraicMergeProjection = {
219211
val bufferSchemata =
220212
offsetAttributes ++ aggregateFunctions.flatMap {
221213
case ae: AlgebraicAggregate => ae.bufferAttributes
@@ -233,7 +225,7 @@ case class Aggregate2Sort(
233225
}
234226

235227
// This projection is used to evaluate all AlgebraicAggregates.
236-
lazy val algebraicEvalProjection = {
228+
private lazy val algebraicEvalProjection = {
237229
val bufferSchemata =
238230
offsetAttributes ++ aggregateFunctions.flatMap {
239231
case ae: AlgebraicAggregate => ae.bufferAttributes
@@ -313,8 +305,9 @@ case class Aggregate2Sort(
313305
// For the below compare method, we do not need to make a copy of groupingKey.
314306
val groupingKey = groupGenerator(currentRow)
315307
// Check if the current row belongs the current input row.
316-
val comparing = groupOrdering.compare(currentGroupingKey, groupingKey)
317-
if (comparing == 0) {
308+
currentGroupingKey.equals(groupingKey)
309+
310+
if (currentGroupingKey == groupingKey) {
318311
processRow(currentRow)
319312
} else {
320313
// We find a new group.

sql/hive/src/test/java/org/apache/spark/sql/hive/execution/Aggregate2Suite.scala

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,19 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {
3232
override def beforeAll(): Unit = {
3333
originalUseAggregate2 = ctx.conf.useSqlAggregate2
3434
ctx.sql("set spark.sql.useAggregate2=true")
35-
val data = Seq[(Int, Integer)](
35+
val data = Seq[(Integer, Integer)](
3636
(1, 10),
37+
(null, -60),
3738
(1, 20),
3839
(1, 30),
3940
(2, 0),
41+
(null, -10),
4042
(2, -1),
4143
(2, null),
4244
(2, null),
45+
(null, 100),
4346
(3, null),
47+
(null, null),
4448
(3, null)).toDF("key", "value")
4549

4650
data.write.saveAsTable("agg2")
@@ -54,7 +58,7 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {
5458
|FROM agg2
5559
|GROUP BY key
5660
""".stripMargin),
57-
Row(-0.5) :: Row(20.0) :: Row(null) :: Nil)
61+
Row(-0.5) :: Row(20.0) :: Row(null) :: Row(10.0) :: Nil)
5862
}
5963

6064
test("test average2") {
@@ -79,7 +83,7 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {
7983
|FROM agg2
8084
|GROUP BY key
8185
""".stripMargin),
82-
Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Nil)
86+
Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Row(null, 10.0) :: Nil)
8387

8488
checkAnswer(
8589
ctx.sql(
@@ -88,7 +92,7 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {
8892
|FROM agg2
8993
|GROUP BY key
9094
""".stripMargin),
91-
Row(20.0, 1) :: Row(-0.5, 2) :: Row(null, 3) :: Nil)
95+
Row(20.0, 1) :: Row(-0.5, 2) :: Row(null, 3) :: Row(10.0, null) :: Nil)
9296

9397
checkAnswer(
9498
ctx.sql(
@@ -97,14 +101,14 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {
97101
|FROM agg2
98102
|GROUP BY key + 10
99103
""".stripMargin),
100-
Row(21.5, 11) :: Row(1.0, 12) :: Row(null, 13) :: Nil)
104+
Row(21.5, 11) :: Row(1.0, 12) :: Row(null, 13) :: Row(11.5, null) :: Nil)
101105

102106
checkAnswer(
103107
ctx.sql(
104108
"""
105109
|SELECT avg(value) FROM agg2
106110
""".stripMargin),
107-
Row(11.8) :: Nil)
111+
Row(11.125) :: Nil)
108112

109113
checkAnswer(
110114
ctx.sql(
@@ -137,14 +141,14 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {
137141
|FROM agg2
138142
|GROUP BY key
139143
""".stripMargin),
140-
Row(60.0, 1) :: Row(-1.0, 2) :: Row(null, 3) :: Nil)
144+
Row(60.0, 1) :: Row(-1.0, 2) :: Row(null, 3) :: Row(30.0, null) :: Nil)
141145

142146
checkAnswer(
143147
ctx.sql(
144148
"""
145149
|SELECT mydoublesum(cast(value as double)) FROM agg2
146150
""".stripMargin),
147-
Row(59.0) :: Nil)
151+
Row(89.0) :: Nil)
148152

149153
checkAnswer(
150154
ctx.sql(
@@ -163,7 +167,10 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {
163167
|FROM agg2
164168
|GROUP BY key
165169
""".stripMargin),
166-
Row(60.0, 1, 20.0) :: Row(-1.0, 2, -0.5) :: Row(null, 3, null) :: Nil)
170+
Row(60.0, 1, 20.0) ::
171+
Row(-1.0, 2, -0.5) ::
172+
Row(null, 3, null) ::
173+
Row(30.0, null, 10.0) :: Nil)
167174

168175
checkAnswer(
169176
ctx.sql(
@@ -179,7 +186,8 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {
179186
""".stripMargin),
180187
Row(64.5, 19.0, 1, 55.5, 20.0) ::
181188
Row(5.0, -2.5, 2, -7.0, -0.5) ::
182-
Row(null, null, 3, null, null) :: Nil)
189+
Row(null, null, 3, null, null) ::
190+
Row(null, null, null, null, 10.0) :: Nil)
183191
}
184192

185193
override def afterAll(): Unit = {

0 commit comments

Comments
 (0)