Skip to content

Commit 9be67c8

Browse files
committed
1 parent 075ce49 commit 9be67c8

File tree

2 files changed

+32
-6
lines changed

2 files changed

+32
-6
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,19 @@ class CoGroupedIterator(
3838
private var currentLeftData: (InternalRow, Iterator[InternalRow]) = _
3939
private var currentRightData: (InternalRow, Iterator[InternalRow]) = _
4040

41-
override def hasNext: Boolean = left.hasNext || right.hasNext
42-
43-
override def next(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
44-
if (currentLeftData.eq(null) && left.hasNext) {
41+
override def hasNext: Boolean = {
42+
if (currentLeftData == null && left.hasNext) {
4543
currentLeftData = left.next()
4644
}
47-
if (currentRightData.eq(null) && right.hasNext) {
45+
if (currentRightData == null && right.hasNext) {
4846
currentRightData = right.next()
4947
}
5048

51-
assert(currentLeftData.ne(null) || currentRightData.ne(null))
49+
currentLeftData != null || currentRightData != null
50+
}
51+
52+
override def next(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
53+
assert(hasNext)
5254

5355
if (currentLeftData.eq(null)) {
5456
// left is null, right is not null, consume the right data.

sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,28 @@ class CoGroupedIteratorSuite extends SparkFunSuite with ExpressionEvalHelper {
4848
Nil
4949
)
5050
}
51+
52+
test("SPARK-11393: respect the fact that GroupedIterator.hasNext is not idempotent") {
53+
val leftInput = Seq(create_row(2, "a")).iterator
54+
val rightInput = Seq(create_row(1, 2L)).iterator
55+
val leftGrouped = GroupedIterator(leftInput, Seq('i.int.at(0)), Seq('i.int, 's.string))
56+
val rightGrouped = GroupedIterator(rightInput, Seq('i.int.at(0)), Seq('i.int, 'l.long))
57+
val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq('i.int))
58+
59+
val result = cogrouped.map {
60+
case (key, leftData, rightData) =>
61+
assert(key.numFields == 1)
62+
(key.getInt(0), leftData.toSeq, rightData.toSeq)
63+
}.toSeq
64+
65+
assert(result ==
66+
(1,
67+
Seq.empty,
68+
Seq(create_row(1, 2L))) ::
69+
(2,
70+
Seq(create_row(2, "a")),
71+
Seq.empty) ::
72+
Nil
73+
)
74+
}
5175
}

0 commit comments

Comments
 (0)