diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index dc5fc2e43143d..3d071df493cec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -295,6 +295,10 @@ case class StreamingSymmetricHashJoinExec( postJoinFilter(joinedRow.withLeft(leftKeyValue.value).withRight(rightValue)) } } + + // NOTE: we need to make sure `outerOutputIter` is evaluated "after" exhausting all of + // the elements in `innerOutputIter`, because evaluation of `innerOutputIter` may update + // the match flag which the logic for outer join is relying on. val removedRowIter = leftSideJoiner.removeOldState() val outerOutputIter = removedRowIter.filterNot { kv => stateFormatVersion match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index 1a0a43c083879..1a5b50dcc7901 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -451,10 +451,25 @@ class SymmetricHashJoinStateManager( } private trait KeyWithIndexToValueRowConverter { + /** Defines the schema of the value row (the value side of K-V in state store). */ def valueAttributes: Seq[Attribute] + /** + * Convert the value row to (actual value, match) pair. + * + * NOTE: implementations should ensure the result row is NOT reused during execution, so + * that the caller can safely read the value at any time. 
+ */ def convertValue(value: UnsafeRow): ValueAndMatchPair + /** + * Build the value row from (actual value, match) pair. This is expected to be called just + * before storing to the state store. + * + * NOTE: depending on the implementation, the result row "may" be reused during execution + * (to avoid initialization of object), so the caller should ensure that the logic isn't + * affected by such behavior. Call copy() against the result row if needed. + */ def convertToValueRow(value: UnsafeRow, matched: Boolean): UnsafeRow } @@ -493,7 +508,7 @@ class SymmetricHashJoinStateManager( override def convertValue(value: UnsafeRow): ValueAndMatchPair = { if (value != null) { - ValueAndMatchPair(valueRowGenerator(value), + ValueAndMatchPair(valueRowGenerator(value).copy(), value.getBoolean(indexOrdinalInValueWithMatchedRow)) } else { null diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index caca749f9dd1e..b182727408bbf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.streaming import java.io.File +import java.sql.Timestamp import java.util.{Locale, UUID} import scala.util.Random @@ -996,4 +997,47 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with ) } } + + test("SPARK-32148 stream-stream join regression on Spark 3.0.0") { + val input1 = MemoryStream[(Timestamp, String, String)] + val df1 = input1.toDF + .selectExpr("_1 as eventTime", "_2 as id", "_3 as comment") + .withWatermark(s"eventTime", "2 minutes") + + val input2 = MemoryStream[(Timestamp, String, String)] + val df2 = input2.toDF + .selectExpr("_1 as eventTime", "_2 as id", "_3 as name") + .withWatermark(s"eventTime", "4 minutes") + + val joined = df1.as("left") + 
.join(df2.as("right"), + expr(""" + |left.id = right.id AND left.eventTime BETWEEN + | right.eventTime - INTERVAL 30 seconds AND + | right.eventTime + INTERVAL 30 seconds + """.stripMargin), + joinType = "leftOuter") + + val inputDataForInput1 = Seq( + (Timestamp.valueOf("2020-01-01 00:00:00"), "abc", "has no join partner"), + (Timestamp.valueOf("2020-01-02 00:00:00"), "abc", "joined with A"), + (Timestamp.valueOf("2020-01-02 01:00:00"), "abc", "joined with B")) + + val inputDataForInput2 = Seq( + (Timestamp.valueOf("2020-01-02 00:00:10"), "abc", "A"), + (Timestamp.valueOf("2020-01-02 00:59:59"), "abc", "B"), + (Timestamp.valueOf("2020-01-02 02:00:00"), "abc", "C")) + + val expectedOutput = Seq( + (Timestamp.valueOf("2020-01-01 00:00:00"), "abc", "has no join partner", null, null, null), + (Timestamp.valueOf("2020-01-02 00:00:00"), "abc", "joined with A", + Timestamp.valueOf("2020-01-02 00:00:10"), "abc", "A"), + (Timestamp.valueOf("2020-01-02 01:00:00"), "abc", "joined with B", + Timestamp.valueOf("2020-01-02 00:59:59"), "abc", "B")) + + testStream(joined)( + MultiAddData((input1, inputDataForInput1), (input2, inputDataForInput2)), + CheckNewAnswer(expectedOutput.head, expectedOutput.tail: _*) + ) + } }