From 82e5a7683de5293529f07c23f39004630175dc96 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Thu, 2 Jul 2020 13:54:24 +0900 Subject: [PATCH 1/5] [SPARK-32148][SS] Fix stream-stream join issue on missing to copy reused unsafe row --- .../StreamingSymmetricHashJoinExec.scala | 4 ++ .../state/SymmetricHashJoinStateManager.scala | 17 ++++++- .../sql/streaming/StreamingJoinSuite.scala | 44 +++++++++++++++++++ 3 files changed, 64 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index dc5fc2e43143d..3d071df493cec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -295,6 +295,10 @@ case class StreamingSymmetricHashJoinExec( postJoinFilter(joinedRow.withLeft(leftKeyValue.value).withRight(rightValue)) } } + + // NOTE: we need to make sure `outerOutputIter` is evaluated "after" exhausting all of + // elements in `innerOutputIter`, because evaluation of `innerOutputIter` may update + // the match flag which the logic for outer join is relying on. val removedRowIter = leftSideJoiner.removeOldState() val outerOutputIter = removedRowIter.filterNot { kv => stateFormatVersion match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index 1a0a43c083879..b4d2b90ba2113 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -451,10 +451,25 @@ class SymmetricHashJoinStateManager( } private trait KeyWithIndexToValueRowConverter { + /** Defines the schema of the value row (the value side of K-V in state store). */ def valueAttributes: Seq[Attribute] + /** + * Convert the value row to (actual value, match) pair. + * + * NOTE: implementations should ensure the result row is NOT reused during execution, as + * caller may use the value to store without copy(). + */ def convertValue(value: UnsafeRow): ValueAndMatchPair + /** + * Build the value row from (actual value, match) pair. This is expected to be called just + * before storing to the state store. + * + * NOTE: depending on the implementation, the result row "may" be reused during execution + * (to avoid initialization of object), so the caller should ensure that the logic doesn't + * affect by such behavior. Call copy() against the result row if needed. + */ def convertToValueRow(value: UnsafeRow, matched: Boolean): UnsafeRow } @@ -493,7 +508,7 @@ class SymmetricHashJoinStateManager( override def convertValue(value: UnsafeRow): ValueAndMatchPair = { if (value != null) { - ValueAndMatchPair(valueRowGenerator(value), + ValueAndMatchPair(valueRowGenerator(value).copy(), value.getBoolean(indexOrdinalInValueWithMatchedRow)) } else { null diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index caca749f9dd1e..f60d6f649efcb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.streaming import java.io.File +import java.sql.Timestamp import java.util.{Locale, UUID} import scala.util.Random @@ -996,4 +997,47 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with ) } } + + test("SPARK-32148 stream-stream join regression on Spark 3.0.0") { + val input1 = MemoryStream[(Timestamp, String, String)] + val df1 = input1.toDF + .selectExpr("_1 as eventTime", "_2 as id", "_3 as comment") + .withWatermark(s"eventTime", "2 minutes") + + val input2 = MemoryStream[(Timestamp, String, String)] + val df2 = input2.toDF + .selectExpr("_1 as eventTime", "_2 as id", "_3 as name") + .withWatermark(s"eventTime", "4 minutes") + + val joined = df1.as("left") + .join(df2.as("right"), + expr(s""" + |left.id = right.id AND left.eventTime BETWEEN + | right.eventTime - INTERVAL 30 seconds AND + | right.eventTime + INTERVAL 30 seconds + """.stripMargin), + joinType = "leftOuter") + + val inputDataForInput1 = Seq( + (Timestamp.valueOf("2020-01-01 00:00:00"), "abc", "has no join partner"), + (Timestamp.valueOf("2020-01-02 00:00:00"), "abc", "joined with A"), + (Timestamp.valueOf("2020-01-02 01:00:00"), "abc", "joined with B")) + + val inputDataForInput2 = Seq( + (Timestamp.valueOf("2020-01-02 00:00:10"), "abc", "A"), + (Timestamp.valueOf("2020-01-02 00:59:59"), "abc", "B"), + (Timestamp.valueOf("2020-01-02 02:00:00"), "abc", "C")) + + val expectedOutput = Seq( + (Timestamp.valueOf("2020-01-01 00:00:00"), "abc", "has no join partner", null, null, null), + (Timestamp.valueOf("2020-01-02 00:00:00"), "abc", "joined with A", + Timestamp.valueOf("2020-01-02 00:00:10"), "abc", "A"), + (Timestamp.valueOf("2020-01-02 01:00:00"), "abc", "joined with B", + Timestamp.valueOf("2020-01-02 00:59:59"), "abc", "B")) + + testStream(joined)( + MultiAddData((input1, inputDataForInput1), (input2, inputDataForInput2)), + CheckNewAnswer(expectedOutput.head, expectedOutput.tail: _*) + ) + } } From be342583629147c537cda14ed1708c3653000b3f Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Sat, 4 Jul 2020 10:54:09 +0900 Subject: [PATCH 2/5] Changed to performance-wise approach, with adding WARN comments to all callers --- .../state/SymmetricHashJoinStateManager.scala | 41 +++++++++++++++---- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index b4d2b90ba2113..196c4b4071ab9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -83,7 +83,13 @@ class SymmetricHashJoinStateManager( ===================================================== */ - /** Get all the values of a key */ + /** + * Get all the values of a key. + * + * NOTE: the returned row "may" be reused during execution (to avoid initialization of object), + * so the caller should ensure that the logic doesn't affect by such behavior. Call copy() + * against the row if needed. + */ def get(key: UnsafeRow): Iterator[UnsafeRow] = { val numValues = keyToNumValues.get(key) keyWithIndexToValue.getAll(key, numValues).map(_.value) @@ -99,6 +105,10 @@ class SymmetricHashJoinStateManager( /** * Get all the matched values for given join condition, with marking matched. * This method is designed to mark joined rows properly without exposing internal index of row. + * + * NOTE: the "value" field in JoinedRow "may" be reused during execution + * (to avoid initialization of object), so the caller should ensure that the logic + * doesn't affect by such behavior. Call copy() against these rows if needed. */ def getJoinedRows( key: UnsafeRow, @@ -250,7 +260,7 @@ class SymmetricHashJoinStateManager( } override def getNext(): KeyToValuePair = { - val currentValue = findNextValueForIndex() + var currentValue = findNextValueForIndex() // If there's no value, clean up and finish. There aren't any more available. if (currentValue == null) { @@ -259,6 +269,9 @@ class SymmetricHashJoinStateManager( return null } + // Make a copy on value row, as below cleanup logic may update the value row silently. + currentValue = currentValue.copy(value = currentValue.value.copy()) + // The backing store is arraylike - we as the caller are responsible for filling back in // any hole. So we swap the last element into the hole and decrement numValues to shorten. // clean @@ -457,8 +470,9 @@ class SymmetricHashJoinStateManager( /** * Convert the value row to (actual value, match) pair. * - * NOTE: implementations should ensure the result row is NOT reused during execution, as - * caller may use the value to store without copy(). + * NOTE: depending on the implementation, the row (actual value) in the pair "may" be reused + * during execution (to avoid initialization of object), so the caller should ensure that + * the logic doesn't affect by such behavior. Call copy() against the row if needed. */ def convertValue(value: UnsafeRow): ValueAndMatchPair @@ -508,7 +522,7 @@ class SymmetricHashJoinStateManager( override def convertValue(value: UnsafeRow): ValueAndMatchPair = { if (value != null) { - ValueAndMatchPair(valueRowGenerator(value).copy(), + ValueAndMatchPair(valueRowGenerator(value), value.getBoolean(indexOrdinalInValueWithMatchedRow)) } else { null @@ -545,13 +559,21 @@ class SymmetricHashJoinStateManager( protected val stateStore = getStateStore(keyWithIndexSchema, valueRowConverter.valueAttributes.toStructType) + /** + * NOTE: the "value" field in return value "may" be reused during execution + * (to avoid initialization of object), so the caller should ensure that the logic + * doesn't affect by such behavior. Call copy() against the row if needed. + */ def get(key: UnsafeRow, valueIndex: Long): ValueAndMatchPair = { valueRowConverter.convertValue(stateStore.get(keyWithIndexRow(key, valueIndex))) } /** - * Get all values and indices for the provided key. - * Should not return null. + * Get all values and indices for the provided key. Should not return null. + * + * NOTE: the "key" and "value" field in return value "may" be reused during execution + * (to avoid initialization of object), so the caller should ensure that the logic + * doesn't affect by such behavior. Call copy() against these rows if needed. */ def getAll(key: UnsafeRow, numValues: Long): Iterator[KeyWithIndexAndValue] = { val keyWithIndexAndValue = new KeyWithIndexAndValue() @@ -598,6 +620,11 @@ class SymmetricHashJoinStateManager( } } + /** + * NOTE: the "key" and "value" field in return value "may" be reused during execution + * (to avoid initialization of object), so the caller should ensure that the logic + * doesn't affect by such behavior. Call copy() against these rows if needed. + */ def iterator: Iterator[KeyWithIndexAndValue] = { val keyWithIndexAndValue = new KeyWithIndexAndValue() stateStore.getRange(None, None).map { pair => From e2201efc4b1f1f32a7a869935d25d94ade55d2d4 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Wed, 8 Jul 2020 06:02:34 +0900 Subject: [PATCH 3/5] Revert "Changed to performance-wise approach, with adding WARN comments to all callers" This reverts commit be342583629147c537cda14ed1708c3653000b3f. --- .../state/SymmetricHashJoinStateManager.scala | 41 ++++--------------- 1 file changed, 7 insertions(+), 34 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index 196c4b4071ab9..b4d2b90ba2113 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -83,13 +83,7 @@ class SymmetricHashJoinStateManager( ===================================================== */ - /** - * Get all the values of a key. - * - * NOTE: the returned row "may" be reused during execution (to avoid initialization of object), - * so the caller should ensure that the logic doesn't affect by such behavior. Call copy() - * against the row if needed. - */ + /** Get all the values of a key */ def get(key: UnsafeRow): Iterator[UnsafeRow] = { val numValues = keyToNumValues.get(key) keyWithIndexToValue.getAll(key, numValues).map(_.value) @@ -105,10 +99,6 @@ class SymmetricHashJoinStateManager( /** * Get all the matched values for given join condition, with marking matched. * This method is designed to mark joined rows properly without exposing internal index of row. - * - * NOTE: the "value" field in JoinedRow "may" be reused during execution - * (to avoid initialization of object), so the caller should ensure that the logic - * doesn't affect by such behavior. Call copy() against these rows if needed. */ def getJoinedRows( key: UnsafeRow, @@ -260,7 +250,7 @@ class SymmetricHashJoinStateManager( } override def getNext(): KeyToValuePair = { - var currentValue = findNextValueForIndex() + val currentValue = findNextValueForIndex() // If there's no value, clean up and finish. There aren't any more available. if (currentValue == null) { @@ -269,9 +259,6 @@ class SymmetricHashJoinStateManager( return null } - // Make a copy on value row, as below cleanup logic may update the value row silently. - currentValue = currentValue.copy(value = currentValue.value.copy()) - // The backing store is arraylike - we as the caller are responsible for filling back in // any hole. So we swap the last element into the hole and decrement numValues to shorten. // clean @@ -470,9 +457,8 @@ class SymmetricHashJoinStateManager( /** * Convert the value row to (actual value, match) pair. * - * NOTE: depending on the implementation, the row (actual value) in the pair "may" be reused - * during execution (to avoid initialization of object), so the caller should ensure that - * the logic doesn't affect by such behavior. Call copy() against the row if needed. + * NOTE: implementations should ensure the result row is NOT reused during execution, as + * caller may use the value to store without copy(). */ def convertValue(value: UnsafeRow): ValueAndMatchPair @@ -522,7 +508,7 @@ class SymmetricHashJoinStateManager( override def convertValue(value: UnsafeRow): ValueAndMatchPair = { if (value != null) { - ValueAndMatchPair(valueRowGenerator(value), + ValueAndMatchPair(valueRowGenerator(value).copy(), value.getBoolean(indexOrdinalInValueWithMatchedRow)) } else { null @@ -559,21 +545,13 @@ class SymmetricHashJoinStateManager( protected val stateStore = getStateStore(keyWithIndexSchema, valueRowConverter.valueAttributes.toStructType) - /** - * NOTE: the "value" field in return value "may" be reused during execution - * (to avoid initialization of object), so the caller should ensure that the logic - * doesn't affect by such behavior. Call copy() against the row if needed. - */ def get(key: UnsafeRow, valueIndex: Long): ValueAndMatchPair = { valueRowConverter.convertValue(stateStore.get(keyWithIndexRow(key, valueIndex))) } /** - * Get all values and indices for the provided key. Should not return null. - * - * NOTE: the "key" and "value" field in return value "may" be reused during execution - * (to avoid initialization of object), so the caller should ensure that the logic - * doesn't affect by such behavior. Call copy() against these rows if needed. + * Get all values and indices for the provided key. + * Should not return null. */ def getAll(key: UnsafeRow, numValues: Long): Iterator[KeyWithIndexAndValue] = { val keyWithIndexAndValue = new KeyWithIndexAndValue() @@ -620,11 +598,6 @@ class SymmetricHashJoinStateManager( } } - /** - * NOTE: the "key" and "value" field in return value "may" be reused during execution - * (to avoid initialization of object), so the caller should ensure that the logic - * doesn't affect by such behavior. Call copy() against these rows if needed. - */ def iterator: Iterator[KeyWithIndexAndValue] = { val keyWithIndexAndValue = new KeyWithIndexAndValue() stateStore.getRange(None, None).map { pair => From 1c011abf3c6e18ecdee608f76926e1b31162a5cc Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Wed, 8 Jul 2020 17:10:07 +0900 Subject: [PATCH 4/5] Update the code comment --- .../streaming/state/SymmetricHashJoinStateManager.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index b4d2b90ba2113..1a5b50dcc7901 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -457,8 +457,8 @@ class SymmetricHashJoinStateManager( /** * Convert the value row to (actual value, match) pair. * - * NOTE: implementations should ensure the result row is NOT reused during execution, as - * caller may use the value to store without copy(). + * NOTE: implementations should ensure the result row is NOT reused during execution, so + * that caller can safely read the value in any time. */ def convertValue(value: UnsafeRow): ValueAndMatchPair From fb63d7ed604e9fc31a960e23e9b00298d5aa8760 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Thu, 9 Jul 2020 07:27:07 +0900 Subject: [PATCH 5/5] Review comment --- .../spark/sql/streaming/StreamingJoinSuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index f60d6f649efcb..b182727408bbf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -1011,11 +1011,11 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with val joined = df1.as("left") .join(df2.as("right"), - expr(s""" - |left.id = right.id AND left.eventTime BETWEEN - | right.eventTime - INTERVAL 30 seconds AND - | right.eventTime + INTERVAL 30 seconds - """.stripMargin), + expr(""" + |left.id = right.id AND left.eventTime BETWEEN + | right.eventTime - INTERVAL 30 seconds AND + | right.eventTime + INTERVAL 30 seconds + """.stripMargin), joinType = "leftOuter") val inputDataForInput1 = Seq(