Skip to content

Commit 31ca780

Browse files
committed
Address part of review comments
1 parent 5f0988e commit 31ca780

File tree

3 files changed

+14
-23
lines changed

3 files changed

+14
-23
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala

Lines changed: 3 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -299,7 +299,7 @@ case class StreamingSymmetricHashJoinExec(
299299
case 1 => matchesWithRightSideState(
300300
new UnsafeRowPair(kvAndMatched.key, kvAndMatched.value))
301301
case 2 => kvAndMatched.matched
302-
case _ => throw new IllegalStateException("Incorrect state format version! " +
302+
case _ => throw new IllegalStateException("Unexpected state format version! " +
303303
s"version $stateFormatVersion")
304304
}
305305
}.map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
@@ -318,7 +318,7 @@ case class StreamingSymmetricHashJoinExec(
318318
case 1 => matchesWithLeftSideState(
319319
new UnsafeRowPair(kvAndMatched.key, kvAndMatched.value))
320320
case 2 => kvAndMatched.matched
321-
case _ => throw new IllegalStateException("Incorrect state format version! " +
321+
case _ => throw new IllegalStateException("Unexpected state format version! " +
322322
s"version $stateFormatVersion")
323323
}
324324
}.map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))
@@ -491,15 +491,7 @@ case class StreamingSymmetricHashJoinExec(
491491
thisRow: UnsafeRow,
492492
subIter: Iterator[JoinedRow])
493493
extends CompletionIterator[JoinedRow, Iterator[JoinedRow]](subIter) {
494-
private var iteratorNotEmpty: Boolean = false
495-
496-
override def hasNext: Boolean = {
497-
val ret = super.hasNext
498-
if (ret && !iteratorNotEmpty) {
499-
iteratorNotEmpty = true
500-
}
501-
ret
502-
}
494+
private val iteratorNotEmpty: Boolean = super.hasNext
503495

504496
override def completion(): Unit = {
505497
val shouldAddToState = // add only if both removal predicates do not match

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala

Lines changed: 8 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -107,7 +107,6 @@ class SymmetricHashJoinStateManager(
107107
val joinedRow = generateJoinedRow(keyIdxToValue.value)
108108
if (predicate(joinedRow)) {
109109
if (!keyIdxToValue.matched) {
110-
// only update when matched flag is false
111110
keyWithIndexToValue.put(key, keyIdxToValue.valueIndex, keyIdxToValue.value,
112111
matched = true)
113112
}
@@ -134,7 +133,7 @@ class SymmetricHashJoinStateManager(
134133
private val allKeyToNumValues = keyToNumValues.iterator
135134

136135
private var currentKeyToNumValue: KeyAndNumValues = null
137-
private var currentValues: Iterator[KeyWithIndexAndValueWithMatched] = null
136+
private var currentValues: Iterator[KeyWithIndexAndValue] = null
138137

139138
private def currentKey = currentKeyToNumValue.key
140139

@@ -412,7 +411,7 @@ class SymmetricHashJoinStateManager(
412411
* Helper class for representing data returned by [[KeyWithIndexToValueStore]].
413412
* Designed for object reuse.
414413
*/
415-
private class KeyWithIndexAndValueWithMatched(
414+
private class KeyWithIndexAndValue(
416415
var key: UnsafeRow = null,
417416
var valueIndex: Long = -1,
418417
var value: UnsafeRow = null,
@@ -520,11 +519,11 @@ class SymmetricHashJoinStateManager(
520519
* Get all values and indices for the provided key.
521520
* Should not return null.
522521
*/
523-
def getAll(key: UnsafeRow, numValues: Long): Iterator[KeyWithIndexAndValueWithMatched] = {
524-
val keyWithIndexAndValueWithMatched = new KeyWithIndexAndValueWithMatched()
522+
def getAll(key: UnsafeRow, numValues: Long): Iterator[KeyWithIndexAndValue] = {
523+
val keyWithIndexAndValueWithMatched = new KeyWithIndexAndValue()
525524
var index = 0
526-
new NextIterator[KeyWithIndexAndValueWithMatched] {
527-
override protected def getNext(): KeyWithIndexAndValueWithMatched = {
525+
new NextIterator[KeyWithIndexAndValue] {
526+
override protected def getNext(): KeyWithIndexAndValue = {
528527
if (index >= numValues) {
529528
finished = true
530529
null
@@ -565,8 +564,8 @@ class SymmetricHashJoinStateManager(
565564
}
566565
}
567566

568-
def iterator: Iterator[KeyWithIndexAndValueWithMatched] = {
569-
val keyWithIndexAndValueWithMatched = new KeyWithIndexAndValueWithMatched()
567+
def iterator: Iterator[KeyWithIndexAndValue] = {
568+
val keyWithIndexAndValueWithMatched = new KeyWithIndexAndValue()
570569
stateStore.getRange(None, None).map { pair =>
571570
val (value, matched) = valueRowConverter.convertValue(pair.value)
572571
keyWithIndexAndValueWithMatched.withNew(

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -800,8 +800,8 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
800800
expr("leftId = rightId AND leftTime >= rightTime AND " +
801801
"leftTime <= rightTime + interval 5 seconds"),
802802
joinType = "rightOuter")
803-
.select(col("rightId"), col("rightTime").cast("int"),
804-
col("leftId"), col("leftTime").cast("int"))
803+
.select(col("leftId"), col("leftTime").cast("int"),
804+
col("rightId"), col("rightTime").cast("int"))
805805

806806
// we can just flip left and right in the explanation of left outer query test
807807
// to assume the status of right outer query, hence skip explaining here
@@ -817,7 +817,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
817817
AddData(inputStream, (11, 11L), (12, 12L), (13, 13L), (14, 14L), (15, 15L)),
818818
CheckNewAnswer(
819819
Row(12, 12L, 12, 12L), Row(14, 14L, 14, 14L),
820-
Row(1, 1L, null, null), Row(3, 3L, null, null)),
820+
Row(null, null, 1, 1L), Row(null, null, 3, 3L)),
821821
assertNumStateRows(15, 7)
822822
)
823823
}

0 commit comments

Comments
 (0)