From 9be67c8ae302c6596aaf34c68aa12ed8c56d058f Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Thu, 29 Oct 2015 12:06:07 +0800
Subject: [PATCH] SPARK-11393

---
# Reviewer notes. Text between the "---" line and the first "diff --git" header
# is commentary only; it is not part of the change being applied.
#
# CoGroupedIterator.scala
#   * hasNext no longer returns `left.hasNext || right.hasNext`. It now pulls the
#     next group from `left` / `right` into currentLeftData / currentRightData
#     whenever that slot is null and the input has more, then returns true if
#     either slot is non-null.
#   * next() no longer does that buffering itself; it now starts with
#     `assert(hasNext)`, so the buffering happens through the hasNext call.
#   * The moved null checks switch from `.eq(null)` / `.ne(null)` to
#     `== null` / `!= null`; the unchanged remainder of next() still uses
#     `.eq(null)`.
#
# CoGroupedIteratorSuite.scala
#   * Adds a test that co-groups a left input holding only key 2 with a right
#     input holding only key 1, and expects two results in order: key 1 with
#     empty left data, then key 2 with empty right data.
#
# NOTE(review): the test title gives the motivation (GroupedIterator.hasNext is
# not idempotent). GroupedIterator is not part of this patch -- confirm there.
# NOTE(review): next() depends on the side effect of hasNext inside assert(...).
# Presumably this is Predef.assert, which is elidable -- confirm assertions are
# never compiled out for this code.
# NOTE(review): line breaks and indentation below were restored from a
# whitespace-collapsed copy; verify context lines against the target tree.
 .../sql/execution/CoGroupedIterator.scala     | 14 ++++++-----
 .../execution/CoGroupedIteratorSuite.scala    | 24 +++++++++++++++++++
 2 files changed, 32 insertions(+), 6 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala
index ce5827855e4aa..663bc904f39c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala
@@ -38,17 +38,19 @@ class CoGroupedIterator(
   private var currentLeftData: (InternalRow, Iterator[InternalRow]) = _
   private var currentRightData: (InternalRow, Iterator[InternalRow]) = _
 
-  override def hasNext: Boolean = left.hasNext || right.hasNext
-
-  override def next(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
-    if (currentLeftData.eq(null) && left.hasNext) {
+  override def hasNext: Boolean = {
+    if (currentLeftData == null && left.hasNext) {
       currentLeftData = left.next()
     }
-    if (currentRightData.eq(null) && right.hasNext) {
+    if (currentRightData == null && right.hasNext) {
       currentRightData = right.next()
     }
 
-    assert(currentLeftData.ne(null) || currentRightData.ne(null))
+    currentLeftData != null || currentRightData != null
+  }
+
+  override def next(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
+    assert(hasNext)
 
     if (currentLeftData.eq(null)) {
       // left is null, right is not null, consume the right data.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala
index d1fe81947e9ea..4ff96e6574cac 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala
@@ -48,4 +48,28 @@ class CoGroupedIteratorSuite extends SparkFunSuite with ExpressionEvalHelper {
       Nil
     )
   }
+
+  test("SPARK-11393: respect the fact that GroupedIterator.hasNext is not idempotent") {
+    val leftInput = Seq(create_row(2, "a")).iterator
+    val rightInput = Seq(create_row(1, 2L)).iterator
+    val leftGrouped = GroupedIterator(leftInput, Seq('i.int.at(0)), Seq('i.int, 's.string))
+    val rightGrouped = GroupedIterator(rightInput, Seq('i.int.at(0)), Seq('i.int, 'l.long))
+    val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq('i.int))
+
+    val result = cogrouped.map {
+      case (key, leftData, rightData) =>
+        assert(key.numFields == 1)
+        (key.getInt(0), leftData.toSeq, rightData.toSeq)
+    }.toSeq
+
+    assert(result ==
+      (1,
+        Seq.empty,
+        Seq(create_row(1, 2L))) ::
+      (2,
+        Seq(create_row(2, "a")),
+        Seq.empty) ::
+      Nil
+    )
+  }
 }