Skip to content

Commit a8cc6b5

Browse files
committed
fix a bug in GroupedIterator and create unit test for it
1 parent b960a89 commit a8cc6b5

File tree

2 files changed

+121
-37
lines changed

2 files changed

+121
-37
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala

Lines changed: 56 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ object GroupedIterator {
2727
keyExpressions: Seq[Expression],
2828
inputSchema: Seq[Attribute]): Iterator[(InternalRow, Iterator[InternalRow])] = {
2929
if (input.hasNext) {
30-
new GroupedIterator(input, keyExpressions, inputSchema)
30+
new GroupedIterator(input.buffered, keyExpressions, inputSchema)
3131
} else {
3232
Iterator.empty
3333
}
@@ -64,7 +64,7 @@ object GroupedIterator {
6464
* @param inputSchema The schema of the rows in the `input` iterator.
6565
*/
6666
class GroupedIterator private(
67-
input: Iterator[InternalRow],
67+
input: BufferedIterator[InternalRow],
6868
groupingExpressions: Seq[Expression],
6969
inputSchema: Seq[Attribute])
7070
extends Iterator[(InternalRow, Iterator[InternalRow])] {
@@ -83,11 +83,12 @@ class GroupedIterator private(
8383

8484
/** Holds a copy of an input row that is in the current group. */
8585
var currentGroup = currentRow.copy()
86-
var currentIterator: Iterator[InternalRow] = null
86+
8787
assert(keyOrdering.compare(currentGroup, currentRow) == 0)
88+
var currentIterator = createGroupValuesIterator()
8889

8990
// Return true if we already have the next iterator or fetching a new iterator is successful.
90-
def hasNext: Boolean = currentIterator != null || fetchNextGroupIterator
91+
def hasNext: Boolean = currentIterator.ne(null) || fetchNextGroupIterator
9192

9293
def next(): (InternalRow, Iterator[InternalRow]) = {
9394
assert(hasNext) // Ensure we have fetched the next iterator.
@@ -96,46 +97,64 @@ class GroupedIterator private(
9697
ret
9798
}
9899

99-
def fetchNextGroupIterator(): Boolean = {
100-
if (currentRow != null || input.hasNext) {
101-
val inputIterator = new Iterator[InternalRow] {
102-
// Return true if we have a row and it is in the current group, or if fetching a new row is
103-
// successful.
104-
def hasNext = {
105-
(currentRow != null && keyOrdering.compare(currentGroup, currentRow) == 0) ||
106-
fetchNextRowInGroup()
107-
}
100+
private def fetchNextGroupIterator(): Boolean = {
101+
assert(currentIterator eq null)
102+
103+
if (currentRow.eq(null) && input.hasNext) {
104+
currentRow = input.next()
105+
}
106+
107+
if (currentRow eq null) {
108+
// These is no data left, return false.
109+
false
110+
} else {
111+
// Skip to next group.
112+
while (input.hasNext && keyOrdering.compare(currentGroup, currentRow) == 0) {
113+
currentRow = input.next()
114+
}
115+
116+
if (keyOrdering.compare(currentGroup, currentRow) == 0) {
117+
// These is no more group. return false.
118+
false
119+
} else {
120+
// Now the `currentRow` is the first row of next group.
121+
currentGroup = currentRow.copy()
122+
currentIterator = createGroupValuesIterator()
123+
true
124+
}
125+
}
126+
}
127+
128+
private def createGroupValuesIterator(): Iterator[InternalRow] = {
129+
new Iterator[InternalRow] {
130+
def hasNext: Boolean = currentRow != null || fetchNextRowInGroup()
131+
132+
def next(): InternalRow = {
133+
assert(hasNext)
134+
val res = currentRow
135+
currentRow = null
136+
res
137+
}
108138

109-
def fetchNextRowInGroup(): Boolean = {
110-
if (currentRow != null || input.hasNext) {
139+
private def fetchNextRowInGroup(): Boolean = {
140+
assert(currentRow eq null)
141+
142+
if (input.hasNext) {
143+
// The inner iterator should NOT consume the input into next group, here we use `head` to
144+
// peek the next input, to see if we should continue to process it.
145+
if (keyOrdering.compare(currentGroup, input.head) == 0) {
146+
// Next input is in the current group. Continue the inner iterator.
111147
currentRow = input.next()
112-
if (keyOrdering.compare(currentGroup, currentRow) == 0) {
113-
// The row is in the current group. Continue the inner iterator.
114-
true
115-
} else {
116-
// We got a row, but its not in the right group. End this inner iterator and prepare
117-
// for the next group.
118-
currentIterator = null
119-
currentGroup = currentRow.copy()
120-
false
121-
}
148+
true
122149
} else {
123-
// There is no more input so we are done.
150+
// Next input is not in the right group. End this inner iterator.
124151
false
125152
}
126-
}
127-
128-
def next(): InternalRow = {
129-
assert(hasNext) // Ensure we have fetched the next row.
130-
val res = currentRow
131-
currentRow = null
132-
res
153+
} else {
154+
// There is no more data, return false.
155+
false
133156
}
134157
}
135-
currentIterator = inputIterator
136-
true
137-
} else {
138-
false
139158
}
140159
}
141160
}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package org.apache.spark.sql.execution
2+
3+
import org.apache.spark.SparkFunSuite
4+
import org.apache.spark.sql.Row
5+
import org.apache.spark.sql.catalyst.dsl.expressions._
6+
import org.apache.spark.sql.catalyst.encoders.RowEncoder
7+
import org.apache.spark.sql.types.{LongType, StringType, IntegerType, StructType}
8+
9+
class GroupedIteratorSuite extends SparkFunSuite {
10+
11+
test("basic") {
12+
val schema = new StructType().add("i", IntegerType).add("s", StringType)
13+
val encoder = RowEncoder(schema)
14+
val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
15+
val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
16+
Seq('i.int.at(0)), schema.toAttributes)
17+
18+
val result = grouped.map {
19+
case (key, data) =>
20+
assert(key.numFields == 1)
21+
key.getInt(0) -> data.map(encoder.fromRow).toSeq
22+
}.toSeq
23+
24+
assert(result ==
25+
1 -> Seq(input(0), input(1)) ::
26+
2 -> Seq(input(2)) :: Nil)
27+
}
28+
29+
test("group by 2 columns") {
30+
val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
31+
val encoder = RowEncoder(schema)
32+
33+
val input = Seq(
34+
Row(1, 2L, "a"),
35+
Row(1, 2L, "b"),
36+
Row(1, 3L, "c"),
37+
Row(2, 1L, "d"),
38+
Row(3, 2L, "e"))
39+
40+
val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
41+
Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)
42+
43+
val result = grouped.map {
44+
case (key, data) =>
45+
assert(key.numFields == 2)
46+
(key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
47+
}.toSeq
48+
49+
assert(result ==
50+
(1, 2L, Seq(input(0), input(1))) ::
51+
(1, 3L, Seq(input(2))) ::
52+
(2, 1L, Seq(input(3))) ::
53+
(3, 2L, Seq(input(4))) :: Nil)
54+
}
55+
56+
test("do nothing to the value iterator") {
57+
val schema = new StructType().add("i", IntegerType).add("s", StringType)
58+
val encoder = RowEncoder(schema)
59+
val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
60+
val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
61+
Seq('i.int.at(0)), schema.toAttributes)
62+
63+
assert(grouped.length == 2)
64+
}
65+
}

0 commit comments

Comments
 (0)