diff --git a/docs/ss-migration-guide.md b/docs/ss-migration-guide.md index b0fd8a8325dff..db8fdff8b2ac4 100644 --- a/docs/ss-migration-guide.md +++ b/docs/ss-migration-guide.md @@ -30,3 +30,4 @@ Please refer [Migration Guide: SQL, Datasets and DataFrame](sql-migration-guide. - In Spark 3.0, Structured Streaming forces the source schema into nullable when file-based datasources such as text, json, csv, parquet and orc are used via `spark.readStream(...)`. Previously, it respected the nullability in source schema; however, it caused issues tricky to debug with NPE. To restore the previous behavior, set `spark.sql.streaming.fileSource.schema.forceNullable` to `false`. +- Spark 3.0 fixes a correctness issue in stream-stream outer join, which changes the schema of state (see SPARK-26154 for more details). Spark 3.0 will fail the query if you start it from a checkpoint constructed by Spark 2.x which uses stream-stream outer join. Please discard the checkpoint and replay previous inputs to recalculate outputs. \ No newline at end of file diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index eebf4b6dfd396..8a51ea4c1f713 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1069,6 +1069,16 @@ object SQLConf { .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2") .createWithDefault(2) + val STREAMING_JOIN_STATE_FORMAT_VERSION = + buildConf("spark.sql.streaming.join.stateFormatVersion") + .internal() + .doc("State format version used by streaming join operations in a streaming query. " + + "State between versions tends to be incompatible, so the state format version shouldn't " + + "be modified after running.") + .intConf + .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2") + .createWithDefault(2) + val UNSUPPORTED_OPERATION_CHECK_ENABLED = buildConf("spark.sql.streaming.unsupportedOperationCheck") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 6e43c9b8bd80b..84fd20d8df0f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -474,8 +474,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _) if left.isStreaming && right.isStreaming => - new StreamingSymmetricHashJoinExec( - leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil + val stateVersion = conf.getConf(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION) + new StreamingSymmetricHashJoinExec(leftKeys, rightKeys, joinType, condition, + stateVersion, planLater(left), planLater(right)) :: Nil case Join(left, right, _, _, _) if left.isStreaming && right.isStreaming => throw new AnalysisException( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 357c049aa18fa..1c59464268444 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -23,7 +23,7 @@ import 
org.json4s.jackson.Serialization import org.apache.spark.internal.Logging import org.apache.spark.sql.RuntimeConfig import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, SparkDataStream} -import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StreamingAggregationStateManager} +import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StreamingAggregationStateManager, SymmetricHashJoinStateManager} import org.apache.spark.sql.internal.SQLConf.{FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, _} @@ -91,7 +91,8 @@ object OffsetSeqMetadata extends Logging { private implicit val format = Serialization.formats(NoTypeHints) private val relevantSQLConfs = Seq( SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY, - FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, STREAMING_AGGREGATION_STATE_FORMAT_VERSION) + FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, STREAMING_AGGREGATION_STATE_FORMAT_VERSION, + STREAMING_JOIN_STATE_FORMAT_VERSION) /** * Default values of relevant configurations that are used for backward compatibility. @@ -108,7 +109,9 @@ object OffsetSeqMetadata extends Logging { FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> FlatMapGroupsWithStateExecHelper.legacyVersion.toString, STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> - StreamingAggregationStateManager.legacyVersion.toString + StreamingAggregationStateManager.legacyVersion.toString, + STREAMING_JOIN_STATE_FORMAT_VERSION.key -> + SymmetricHashJoinStateManager.legacyVersion.toString ) def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index 50cf971e4ec3c..6bb4dc1672900 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan} import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._ import org.apache.spark.sql.execution.streaming.state._ +import org.apache.spark.sql.execution.streaming.state.SymmetricHashJoinStateManager.KeyToValuePair import org.apache.spark.sql.internal.SessionState import org.apache.spark.util.{CompletionIterator, SerializableConfiguration} @@ -131,6 +132,7 @@ case class StreamingSymmetricHashJoinExec( stateInfo: Option[StatefulOperatorStateInfo], eventTimeWatermark: Option[Long], stateWatermarkPredicates: JoinStateWatermarkPredicates, + stateFormatVersion: Int, left: SparkPlan, right: SparkPlan) extends SparkPlan with BinaryExecNode with StateStoreWriter { @@ -139,13 +141,20 @@ case class StreamingSymmetricHashJoinExec( rightKeys: Seq[Expression], joinType: JoinType, condition: Option[Expression], + stateFormatVersion: Int, left: SparkPlan, right: SparkPlan) = { this( leftKeys, rightKeys, joinType, JoinConditionSplitPredicates(condition, left, right), stateInfo = None, eventTimeWatermark = None, - stateWatermarkPredicates = JoinStateWatermarkPredicates(), left, right) + stateWatermarkPredicates = JoinStateWatermarkPredicates(), stateFormatVersion, left, right) + } + + if (stateFormatVersion < 2 && joinType != Inner) { + throw 
new IllegalArgumentException("The query is using stream-stream outer join with state" + s" format version ${stateFormatVersion} - a correctness issue was discovered. Please discard" + " the checkpoint and rerun the query. See SPARK-26154 for more details.") } private def throwBadJoinTypeException(): Nothing = { @@ -270,20 +279,30 @@ case class StreamingSymmetricHashJoinExec( // * Getting an iterator over the rows that have aged out on the left side. These rows are // candidates for being null joined. Note that to avoid doing two passes, this iterator // removes the rows from the state manager as they're processed. - // * Checking whether the current row matches a key in the right side state, and that key - // has any value which satisfies the filter function when joined. If it doesn't, - // we know we can join with null, since there was never (including this batch) a match - // within the watermark period. If it does, there must have been a match at some point, so - // we know we can't join with null. + // * (state format version 1) Checking whether the current row matches a key in the + // right side state, and that key has any value which satisfies the filter function when + // joined. If it doesn't, we know we can join with null, since there was never + // (including this batch) a match within the watermark period. If it does, there must have + // been a match at some point, so we know we can't join with null. + // * (state format version 2) The above approach has an edge case which causes a correctness + // issue, so we had to take another approach (see SPARK-26154); Spark now stores a 'matched' + // flag along with the row, which is set to true when there is any matching row on the right. + def matchesWithRightSideState(leftKeyValue: UnsafeRowPair) = { rightSideJoiner.get(leftKeyValue.key).exists { rightValue => postJoinFilter(joinedRow.withLeft(leftKeyValue.value).withRight(rightValue)) } } val removedRowIter = leftSideJoiner.removeOldState() - val outerOutputIter = removedRowIter - .filterNot(pair => matchesWithRightSideState(pair)) - .map(pair => joinedRow.withLeft(pair.value).withRight(nullRight)) + val outerOutputIter = removedRowIter.filterNot { kv => + stateFormatVersion match { + case 1 => matchesWithRightSideState(new UnsafeRowPair(kv.key, kv.value)) + case 2 => kv.matched + case _ => + throw new IllegalStateException("Unexpected state format version! " + + s"version $stateFormatVersion") + } + }.map(pair => joinedRow.withLeft(pair.value).withRight(nullRight)) innerOutputIter ++ outerOutputIter case RightOuter => @@ -294,9 +313,15 @@ case class StreamingSymmetricHashJoinExec( } } val removedRowIter = rightSideJoiner.removeOldState() - val outerOutputIter = removedRowIter - .filterNot(pair => matchesWithLeftSideState(pair)) - .map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value)) + val outerOutputIter = removedRowIter.filterNot { kv => + stateFormatVersion match { + case 1 => matchesWithLeftSideState(new UnsafeRowPair(kv.key, kv.value)) + case 2 => kv.matched + case _ => + throw new IllegalStateException("Unexpected state format version! 
" + + s"version $stateFormatVersion") + } + }.map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value)) innerOutputIter ++ outerOutputIter case _ => throwBadJoinTypeException() @@ -395,7 +420,8 @@ case class StreamingSymmetricHashJoinExec( newPredicate(preJoinFilterExpr.getOrElse(Literal(true)), inputAttributes).eval _ private val joinStateManager = new SymmetricHashJoinStateManager( - joinSide, inputAttributes, joinKeys, stateInfo, storeConf, hadoopConfBcast.value.value) + joinSide, inputAttributes, joinKeys, stateInfo, storeConf, hadoopConfBcast.value.value, + stateFormatVersion) private[this] val keyGenerator = UnsafeProjection.create(joinKeys, inputAttributes) private[this] val stateKeyWatermarkPredicateFunc = stateWatermarkPredicate match { @@ -445,16 +471,9 @@ case class StreamingSymmetricHashJoinExec( // the case of inner join). if (preJoinFilter(thisRow)) { val key = keyGenerator(thisRow) - val outputIter = otherSideJoiner.joinStateManager.get(key).map { thatRow => - generateJoinedRow(thisRow, thatRow) - }.filter(postJoinFilter) - val shouldAddToState = // add only if both removal predicates do not match - !stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow) - if (shouldAddToState) { - joinStateManager.append(key, thisRow) - updatedStateRowsCount += 1 - } - outputIter + val outputIter: Iterator[JoinedRow] = otherSideJoiner.joinStateManager + .getJoinedRows(key, thatRow => generateJoinedRow(thisRow, thatRow), postJoinFilter) + new AddingProcessedRowToStateCompletionIterator(key, thisRow, outputIter) } else { joinSide match { case LeftSide if joinType == LeftOuter => @@ -467,6 +486,23 @@ case class StreamingSymmetricHashJoinExec( } } + private class AddingProcessedRowToStateCompletionIterator( + key: UnsafeRow, + thisRow: UnsafeRow, + subIter: Iterator[JoinedRow]) + extends CompletionIterator[JoinedRow, Iterator[JoinedRow]](subIter) { + private val iteratorNotEmpty: Boolean = super.hasNext + + override def completion(): Unit = { + val shouldAddToState = // add only if both removal predicates do not match + !stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow) + if (shouldAddToState) { + joinStateManager.append(key, thisRow, matched = iteratorNotEmpty) + updatedStateRowsCount += 1 + } + } + } + /** * Get an iterator over the values stored in this joiner's state manager for the given key. * @@ -486,7 +522,7 @@ case class StreamingSymmetricHashJoinExec( * We do this to avoid requiring either two passes or full materialization when * processing the rows for outer join. 
*/ - def removeOldState(): Iterator[UnsafeRowPair] = { + def removeOldState(): Iterator[KeyToValuePair] = { stateWatermarkPredicate match { case Some(JoinStateKeyWatermarkPredicate(expr)) => joinStateManager.removeByKeyCondition(stateKeyWatermarkPredicateFunc) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index 43f22803e7685..c10713734dcc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -23,10 +23,12 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.TaskContext import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Literal, SpecificInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, JoinedRow, Literal, SpecificInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.{StatefulOperatorStateInfo, StreamingSymmetricHashJoinExec} import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._ -import org.apache.spark.sql.types.{LongType, StructField, StructType} +import org.apache.spark.sql.execution.streaming.state.SymmetricHashJoinStateManager.KeyToValuePair +import org.apache.spark.sql.types.{BooleanType, LongType, StructField, StructType} import org.apache.spark.util.NextIterator /** @@ -42,10 +44,14 @@ import org.apache.spark.util.NextIterator * @param stateInfo Information about how to retrieve the correct version of state * @param storeConf Configuration for the state store. * @param hadoopConf Hadoop configuration for reading state data from storage + * @param stateFormatVersion The version of the state format. * * Internally, the key -> multiple values is stored in two [[StateStore]]s. 
* - Store 1 ([[KeyToNumValuesStore]]) maintains mapping between key -> number of values - * - Store 2 ([[KeyWithIndexToValueStore]]) maintains mapping between (key, index) -> value + * - Store 2 ([[KeyWithIndexToValueStore]]) maintains a mapping; the mapping depends on the state + * format version: + * - version 1: [(key, index) -> value] + * - version 2: [(key, index) -> (value, matched)] * - Put: update count in KeyToNumValuesStore, * insert new (key, count) -> value in KeyWithIndexToValueStore * - Get: read count from KeyToNumValuesStore, @@ -54,7 +60,7 @@ import org.apache.spark.util.NextIterator * scan all keys in KeyToNumValuesStore to find keys that do match the predicate, * delete from key from KeyToNumValuesStore, delete values in KeyWithIndexToValueStore * - Remove state by condition on values: - * scan all [(key, index) -> value] in KeyWithIndexToValueStore to find values that match + * scan all elements in KeyWithIndexToValueStore to find values that match * the predicate, delete corresponding (key, indexToDelete) from KeyWithIndexToValueStore * by overwriting with the value of (key, maxIndex), and removing [(key, maxIndex), * decrement corresponding num values in KeyToNumValuesStore @@ -65,7 +71,8 @@ class SymmetricHashJoinStateManager( joinKeys: Seq[Expression], stateInfo: Option[StatefulOperatorStateInfo], storeConf: StateStoreConf, - hadoopConf: Configuration) extends Logging { + hadoopConf: Configuration, + stateFormatVersion: Int) extends Logging { import SymmetricHashJoinStateManager._ @@ -82,23 +89,46 @@ class SymmetricHashJoinStateManager( } /** Append a new value to the key */ - def append(key: UnsafeRow, value: UnsafeRow): Unit = { + def append(key: UnsafeRow, value: UnsafeRow, matched: Boolean): Unit = { val numExistingValues = keyToNumValues.get(key) - keyWithIndexToValue.put(key, numExistingValues, value) + keyWithIndexToValue.put(key, numExistingValues, value, matched) keyToNumValues.put(key, numExistingValues + 1) } + /** + * Get all the joined rows for the given key that satisfy the given predicate, marking the + * corresponding state values as matched, without exposing the internal index of the row. + */ + def getJoinedRows( + key: UnsafeRow, + generateJoinedRow: InternalRow => JoinedRow, + predicate: JoinedRow => Boolean): Iterator[JoinedRow] = { + val numValues = keyToNumValues.get(key) + keyWithIndexToValue.getAll(key, numValues).map { keyIdxToValue => + val joinedRow = generateJoinedRow(keyIdxToValue.value) + if (predicate(joinedRow)) { + if (!keyIdxToValue.matched) { + keyWithIndexToValue.put(key, keyIdxToValue.valueIndex, keyIdxToValue.value, + matched = true) + } + joinedRow + } else { + null + } + }.filter(_ != null) + } + /** * Remove using a predicate on keys. * - * This produces an iterator over the (key, value) pairs satisfying condition(key), where the - * underlying store is updated as a side-effect of producing next. + * This produces an iterator over the (key, value, matched) tuples satisfying condition(key), + * where the underlying store is updated as a side-effect of producing next. * * This implies the iterator must be consumed fully without any other operations on this manager * or the underlying store being interleaved. 
*/ - def removeByKeyCondition(removalCondition: UnsafeRow => Boolean): Iterator[UnsafeRowPair] = { - new NextIterator[UnsafeRowPair] { + def removeByKeyCondition(removalCondition: UnsafeRow => Boolean): Iterator[KeyToValuePair] = { + new NextIterator[KeyToValuePair] { private val allKeyToNumValues = keyToNumValues.iterator @@ -107,15 +137,15 @@ class SymmetricHashJoinStateManager( private def currentKey = currentKeyToNumValue.key - private val reusedPair = new UnsafeRowPair() + private val reusedRet = new KeyToValuePair() - private def getAndRemoveValue() = { + private def getAndRemoveValue(): KeyToValuePair = { val keyWithIndexAndValue = currentValues.next() keyWithIndexToValue.remove(currentKey, keyWithIndexAndValue.valueIndex) - reusedPair.withRows(currentKey, keyWithIndexAndValue.value) + reusedRet.withNew(currentKey, keyWithIndexAndValue.value, keyWithIndexAndValue.matched) } - override def getNext(): UnsafeRowPair = { + override def getNext(): KeyToValuePair = { // If there are more values for the current key, remove and return the next one. if (currentValues != null && currentValues.hasNext) { return getAndRemoveValue() @@ -126,8 +156,7 @@ class SymmetricHashJoinStateManager( while (allKeyToNumValues.hasNext) { currentKeyToNumValue = allKeyToNumValues.next() if (removalCondition(currentKey)) { - currentValues = keyWithIndexToValue.getAll( - currentKey, currentKeyToNumValue.numValue) + currentValues = keyWithIndexToValue.getAll(currentKey, currentKeyToNumValue.numValue) keyToNumValues.remove(currentKey) if (currentValues.hasNext) { @@ -148,18 +177,18 @@ class SymmetricHashJoinStateManager( /** * Remove using a predicate on values. * - * At a high level, this produces an iterator over the (key, value) pairs such that value - * satisfies the predicate, where producing an element removes the value from the state store - * and producing all elements with a given key updates it accordingly. + * At a high level, this produces an iterator over the (key, value, matched) tuples such that + * value satisfies the predicate, where producing an element removes the value from the + * state store and producing all elements with a given key updates it accordingly. * * This implies the iterator must be consumed fully without any other operations on this manager * or the underlying store being interleaved. */ - def removeByValueCondition(removalCondition: UnsafeRow => Boolean): Iterator[UnsafeRowPair] = { - new NextIterator[UnsafeRowPair] { + def removeByValueCondition(removalCondition: UnsafeRow => Boolean): Iterator[KeyToValuePair] = { + new NextIterator[KeyToValuePair] { // Reuse this object to avoid creation+GC overhead. - private val reusedPair = new UnsafeRowPair() + private val reusedRet = new KeyToValuePair() private val allKeyToNumValues = keyToNumValues.iterator @@ -187,7 +216,7 @@ class SymmetricHashJoinStateManager( // Find the next value satisfying the condition, updating `currentKey` and `numValues` if // needed. Returns null when no value can be found. - private def findNextValueForIndex(): UnsafeRow = { + private def findNextValueForIndex(): ValueAndMatchPair = { // Loop across all values for the current key, and then all other keys, until we find a // value satisfying the removal condition. def hasMoreValuesForCurrentKey = currentKey != null && index < numValues @@ -195,9 +224,9 @@ class SymmetricHashJoinStateManager( while (hasMoreValuesForCurrentKey || hasMoreKeys) { if (hasMoreValuesForCurrentKey) { // First search the values for the current key. 
- val currentValue = keyWithIndexToValue.get(currentKey, index) - if (removalCondition(currentValue)) { - return currentValue + val valuePair = keyWithIndexToValue.get(currentKey, index) + if (removalCondition(valuePair.value)) { + return valuePair } else { index += 1 } @@ -219,7 +248,7 @@ class SymmetricHashJoinStateManager( return null } - override def getNext(): UnsafeRowPair = { + override def getNext(): KeyToValuePair = { val currentValue = findNextValueForIndex() // If there's no value, clean up and finish. There aren't any more available. @@ -233,8 +262,13 @@ class SymmetricHashJoinStateManager( // any hole. So we swap the last element into the hole and decrement numValues to shorten. // clean if (numValues > 1) { - val valueAtMaxIndex = keyWithIndexToValue.get(currentKey, numValues - 1) - keyWithIndexToValue.put(currentKey, index, valueAtMaxIndex) + val valuePairAtMaxIndex = keyWithIndexToValue.get(currentKey, numValues - 1) + if (valuePairAtMaxIndex != null) { + keyWithIndexToValue.put(currentKey, index, valuePairAtMaxIndex.value, + valuePairAtMaxIndex.matched) + } else { + keyWithIndexToValue.put(currentKey, index, null, false) + } keyWithIndexToValue.remove(currentKey, numValues - 1) } else { keyWithIndexToValue.remove(currentKey, 0) @@ -242,7 +276,7 @@ class SymmetricHashJoinStateManager( numValues -= 1 valueRemoved = true - return reusedPair.withRows(currentKey, currentValue) + return reusedRet.withNew(currentKey, currentValue.value, currentValue.matched) } override def close: Unit = {} @@ -294,7 +328,7 @@ class SymmetricHashJoinStateManager( joinKeys.zipWithIndex.map { case (k, i) => StructField(s"field$i", k.dataType, k.nullable) }) private val keyAttributes = keySchema.toAttributes private val keyToNumValues = new KeyToNumValuesStore() - private val keyWithIndexToValue = new KeyWithIndexToValueStore() + private val keyWithIndexToValue = new KeyWithIndexToValueStore(stateFormatVersion) // Clean up any state store resources if necessary at the end of the task Option(TaskContext.get()).foreach { _.addTaskCompletionListener[Unit] { _ => abortIfNeeded() } } @@ -335,7 +369,7 @@ class SymmetricHashJoinStateManager( * Helper class for representing data returned by [[KeyWithIndexToValueStore]]. * Designed for object reuse. */ - private case class KeyAndNumValues(var key: UnsafeRow = null, var numValue: Long = 0) { + private class KeyAndNumValues(var key: UnsafeRow = null, var numValue: Long = 0) { def withNew(newKey: UnsafeRow, newNumValues: Long): this.type = { this.key = newKey this.numValue = newNumValues @@ -380,18 +414,105 @@ class SymmetricHashJoinStateManager( * Helper class for representing data returned by [[KeyWithIndexToValueStore]]. * Designed for object reuse. 
*/ - private case class KeyWithIndexAndValue( - var key: UnsafeRow = null, var valueIndex: Long = -1, var value: UnsafeRow = null) { - def withNew(newKey: UnsafeRow, newIndex: Long, newValue: UnsafeRow): this.type = { + private class KeyWithIndexAndValue( + var key: UnsafeRow = null, + var valueIndex: Long = -1, + var value: UnsafeRow = null, + var matched: Boolean = false) { + + def withNew( + newKey: UnsafeRow, + newIndex: Long, + newValue: UnsafeRow, + newMatched: Boolean): this.type = { this.key = newKey this.valueIndex = newIndex this.value = newValue + this.matched = newMatched + this + } + + def withNew( + newKey: UnsafeRow, + newIndex: Long, + newValue: ValueAndMatchPair): this.type = { + this.key = newKey + this.valueIndex = newIndex + if (newValue != null) { + this.value = newValue.value + this.matched = newValue.matched + } else { + this.value = null + this.matched = false + } this } } - /** A wrapper around a [[StateStore]] that stores [(key, index) -> value]. */ - private class KeyWithIndexToValueStore extends StateStoreHandler(KeyWithIndexToValueType) { + private trait KeyWithIndexToValueRowConverter { + def valueAttributes: Seq[Attribute] + + def convertValue(value: UnsafeRow): ValueAndMatchPair + + def convertToValueRow(value: UnsafeRow, matched: Boolean): UnsafeRow + } + + private object KeyWithIndexToValueRowConverter { + def create(version: Int): KeyWithIndexToValueRowConverter = version match { + case 1 => new KeyWithIndexToValueRowConverterFormatV1() + case 2 => new KeyWithIndexToValueRowConverterFormatV2() + case _ => throw new IllegalArgumentException("Incorrect state format version! " + + s"version $version") + } + } + + private class KeyWithIndexToValueRowConverterFormatV1 extends KeyWithIndexToValueRowConverter { + override val valueAttributes: Seq[Attribute] = inputValueAttributes + + override def convertValue(value: UnsafeRow): ValueAndMatchPair = { + if (value != null) ValueAndMatchPair(value, false) else null + } + + override def convertToValueRow(value: UnsafeRow, matched: Boolean): UnsafeRow = value + } + + private class KeyWithIndexToValueRowConverterFormatV2 extends KeyWithIndexToValueRowConverter { + private val valueWithMatchedExprs = inputValueAttributes :+ Literal(true) + private val indexOrdinalInValueWithMatchedRow = inputValueAttributes.size + + private val valueWithMatchedRowGenerator = UnsafeProjection.create(valueWithMatchedExprs, + inputValueAttributes) + + override val valueAttributes: Seq[Attribute] = inputValueAttributes :+ + AttributeReference("matched", BooleanType)() + + // Projection to generate the value row from the (value + matched) row + private val valueRowGenerator = UnsafeProjection.create( + inputValueAttributes, valueAttributes) + + override def convertValue(value: UnsafeRow): ValueAndMatchPair = { + if (value != null) { + ValueAndMatchPair(valueRowGenerator(value), + value.getBoolean(indexOrdinalInValueWithMatchedRow)) + } else { + null + } + } + + override def convertToValueRow(value: UnsafeRow, matched: Boolean): UnsafeRow = { + val row = valueWithMatchedRowGenerator(value) + row.setBoolean(indexOrdinalInValueWithMatchedRow, matched) + row + } + } + + /** + * A wrapper around a [[StateStore]] that stores the mapping; the mapping depends on the + * state format version - please refer to the implementations of [[KeyWithIndexToValueRowConverter]]. 
+ */ + private class KeyWithIndexToValueStore(stateFormatVersion: Int) + extends StateStoreHandler(KeyWithIndexToValueType) { + private val keyWithIndexExprs = keyAttributes :+ Literal(1L) private val keyWithIndexSchema = keySchema.add("index", LongType) private val indexOrdinalInKeyWithIndexRow = keyAttributes.size @@ -403,10 +524,13 @@ class SymmetricHashJoinStateManager( private val keyRowGenerator = UnsafeProjection.create( keyAttributes, keyAttributes :+ AttributeReference("index", LongType)()) - protected val stateStore = getStateStore(keyWithIndexSchema, inputValueAttributes.toStructType) + private val valueRowConverter = KeyWithIndexToValueRowConverter.create(stateFormatVersion) + + protected val stateStore = getStateStore(keyWithIndexSchema, + valueRowConverter.valueAttributes.toStructType) - def get(key: UnsafeRow, valueIndex: Long): UnsafeRow = { - stateStore.get(keyWithIndexRow(key, valueIndex)) + def get(key: UnsafeRow, valueIndex: Long): ValueAndMatchPair = { + valueRowConverter.convertValue(stateStore.get(keyWithIndexRow(key, valueIndex))) } /** @@ -423,8 +547,8 @@ class SymmetricHashJoinStateManager( null } else { val keyWithIndex = keyWithIndexRow(key, index) - val value = stateStore.get(keyWithIndex) - keyWithIndexAndValue.withNew(key, index, value) + val valuePair = valueRowConverter.convertValue(stateStore.get(keyWithIndex)) + keyWithIndexAndValue.withNew(key, index, valuePair) index += 1 keyWithIndexAndValue } @@ -435,9 +559,10 @@ class SymmetricHashJoinStateManager( } /** Put new value for key at the given index */ - def put(key: UnsafeRow, valueIndex: Long, value: UnsafeRow): Unit = { + def put(key: UnsafeRow, valueIndex: Long, value: UnsafeRow, matched: Boolean): Unit = { val keyWithIndex = keyWithIndexRow(key, valueIndex) - stateStore.put(keyWithIndex, value) + val valueWithMatched = valueRowConverter.convertToValueRow(value, matched) + stateStore.put(keyWithIndex, valueWithMatched) } /** @@ -460,8 +585,9 @@ class SymmetricHashJoinStateManager( def iterator: Iterator[KeyWithIndexAndValue] = { val keyWithIndexAndValue = new KeyWithIndexAndValue() stateStore.getRange(None, None).map { pair => + val valuePair = valueRowConverter.convertValue(pair.value) keyWithIndexAndValue.withNew( - keyRowGenerator(pair.key), pair.key.getLong(indexOrdinalInKeyWithIndexRow), pair.value) + keyRowGenerator(pair.key), pair.key.getLong(indexOrdinalInKeyWithIndexRow), valuePair) keyWithIndexAndValue } } @@ -476,6 +602,8 @@ class SymmetricHashJoinStateManager( } object SymmetricHashJoinStateManager { + val supportedVersions = Seq(1, 2) + val legacyVersion = 1 def allStateStoreNames(joinSides: JoinSide*): Seq[String] = { val allStateStoreTypes: Seq[StateStoreType] = Seq(KeyToNumValuesType, KeyWithIndexToValueType) @@ -497,4 +625,35 @@ object SymmetricHashJoinStateManager { private def getStateStoreName(joinSide: JoinSide, storeType: StateStoreType): String = { s"$joinSide-$storeType" } + + /** Helper class for representing data (value, matched). */ + case class ValueAndMatchPair(value: UnsafeRow, matched: Boolean) + + /** + * Helper class for representing data key to (value, matched). + * Designed for object reuse. 
+ */ + case class KeyToValuePair( + var key: UnsafeRow = null, + var value: UnsafeRow = null, + var matched: Boolean = false) { + def withNew(newKey: UnsafeRow, newValue: UnsafeRow, newMatched: Boolean): this.type = { + this.key = newKey + this.value = newValue + this.matched = newMatched + this + } + + def withNew(newKey: UnsafeRow, newValue: ValueAndMatchPair): this.type = { + this.key = newKey + if (newValue != null) { + this.value = newValue.value + this.matched = newValue.matched + } else { + this.value = null + this.matched = false + } + this + } + } } diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/commits/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/commits/0 new file mode 100644 index 0000000000000..9c1e3021c3ead --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/commits/0 @@ -0,0 +1,2 @@ +v1 +{"nextBatchWatermarkMs":0} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/metadata new file mode 100644 index 0000000000000..543f156048abe --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/metadata @@ -0,0 +1 @@ +{"id":"1ab1ee6f-993c-4a51-824c-1c7cc8202f62"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/offsets/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/offsets/0 new file mode 100644 index 0000000000000..63dba425b7e16 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/offsets/0 @@ -0,0 +1,4 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1548845804202,"conf":{"spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider","spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion":"2","spark.sql.streaming.multipleWatermarkPolicy":"min","spark.sql.streaming.aggregation.stateFormatVersion":"2","spark.sql.shuffle.partitions":"5"}} +0 +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/0/left-keyToNumValues/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/0/left-keyToNumValues/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/0/left-keyToNumValues/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/0/left-keyWithIndexToValue/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/0/left-keyWithIndexToValue/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/0/left-keyWithIndexToValue/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/0/right-keyToNumValues/1.delta 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/0/right-keyToNumValues/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/0/right-keyToNumValues/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/0/right-keyWithIndexToValue/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/0/right-keyWithIndexToValue/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/0/right-keyWithIndexToValue/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/1/left-keyToNumValues/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/1/left-keyToNumValues/1.delta new file mode 100644 index 0000000000000..2cdf645d3a406 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/1/left-keyToNumValues/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/1/left-keyWithIndexToValue/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/1/left-keyWithIndexToValue/1.delta new file mode 100644 index 0000000000000..9c69d01231196 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/1/left-keyWithIndexToValue/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/1/right-keyToNumValues/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/1/right-keyToNumValues/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/1/right-keyToNumValues/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/1/right-keyWithIndexToValue/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/1/right-keyWithIndexToValue/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/1/right-keyWithIndexToValue/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/2/left-keyToNumValues/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/2/left-keyToNumValues/1.delta new file mode 100644 index 0000000000000..4e421cd377fb6 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/2/left-keyToNumValues/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/2/left-keyWithIndexToValue/1.delta 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/2/left-keyWithIndexToValue/1.delta new file mode 100644 index 0000000000000..edc7a97408aaa Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/2/left-keyWithIndexToValue/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/2/right-keyToNumValues/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/2/right-keyToNumValues/1.delta new file mode 100644 index 0000000000000..4e421cd377fb6 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/2/right-keyToNumValues/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/2/right-keyWithIndexToValue/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/2/right-keyWithIndexToValue/1.delta new file mode 100644 index 0000000000000..edc7a97408aaa Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/2/right-keyWithIndexToValue/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/3/left-keyToNumValues/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/3/left-keyToNumValues/1.delta new file mode 100644 index 0000000000000..859c2b1315a5e Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/3/left-keyToNumValues/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/3/left-keyWithIndexToValue/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/3/left-keyWithIndexToValue/1.delta new file mode 100644 index 0000000000000..7535621b3adb2 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/3/left-keyWithIndexToValue/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/3/right-keyToNumValues/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/3/right-keyToNumValues/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/3/right-keyToNumValues/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/3/right-keyWithIndexToValue/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/3/right-keyWithIndexToValue/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/3/right-keyWithIndexToValue/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/4/left-keyToNumValues/1.delta 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/4/left-keyToNumValues/1.delta new file mode 100644 index 0000000000000..0bdaf341003b9 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/4/left-keyToNumValues/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/4/left-keyWithIndexToValue/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/4/left-keyWithIndexToValue/1.delta new file mode 100644 index 0000000000000..f17037b3c5218 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/4/left-keyWithIndexToValue/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/4/right-keyToNumValues/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/4/right-keyToNumValues/1.delta new file mode 100644 index 0000000000000..0bdaf341003b9 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/4/right-keyToNumValues/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/4/right-keyWithIndexToValue/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/4/right-keyWithIndexToValue/1.delta new file mode 100644 index 0000000000000..f17037b3c5218 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join/state/0/4/right-keyWithIndexToValue/1.delta differ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala index c0216a2ef3e61..b40f8df22b586 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala @@ -38,9 +38,14 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter spark.streams.stateStoreCoordinator // initialize the lazy coordinator } + SymmetricHashJoinStateManager.supportedVersions.foreach { version => + test(s"StreamingJoinStateManager V${version} - all operations") { + testAllOperations(version) + } + } - test("SymmetricHashJoinStateManager - all operations") { - withJoinStateManager(inputValueAttribs, joinKeyExprs) { manager => + private def testAllOperations(stateFormatVersion: Int): Unit = { + withJoinStateManager(inputValueAttribs, joinKeyExprs, stateFormatVersion) { manager => implicit val mgr = manager assert(get(20) === Seq.empty) // initially empty @@ -123,7 +128,8 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter def toValueInt(inputValueRow: UnsafeRow): Int = inputValueRow.getInt(0) def append(key: Int, value: Int)(implicit manager: SymmetricHashJoinStateManager): Unit = { - manager.append(toJoinKeyRow(key), toInputValue(value)) + // we only put matched = false for simplicity - StreamingJoinSuite will test the functionality + manager.append(toJoinKeyRow(key), toInputValue(value), 
matched = false) } def get(key: Int)(implicit manager: SymmetricHashJoinStateManager): Seq[Int] = { @@ -156,13 +162,15 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter def withJoinStateManager( inputValueAttribs: Seq[Attribute], - joinKeyExprs: Seq[Expression])(f: SymmetricHashJoinStateManager => Unit): Unit = { + joinKeyExprs: Seq[Expression], + stateFormatVersion: Int)(f: SymmetricHashJoinStateManager => Unit): Unit = { withTempDir { file => val storeConf = new StateStoreConf() val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5) val manager = new SymmetricHashJoinStateManager( - LeftSide, inputValueAttribs, joinKeyExprs, Some(stateInfo), storeConf, new Configuration) + LeftSide, inputValueAttribs, joinKeyExprs, Some(stateInfo), storeConf, new Configuration, + stateFormatVersion) try { f(manager) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 42fe9f34ee3ec..ae6a4ecb7a6da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.streaming -import java.util.UUID +import java.io.File +import java.util.{Locale, UUID} import scala.util.Random +import org.apache.commons.io.FileUtils import org.scalatest.BeforeAndAfter import org.apache.spark.scheduler.ExecutorCacheTaskLocation @@ -31,7 +33,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, Filter} import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.execution.{FileSourceScanExec, LogicalRDD} import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinHelper} +import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinExec, StreamingSymmetricHashJoinHelper} import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreProviderId} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -418,6 +420,63 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(input2, 1.to(1000): _*), CheckAnswer(1.to(1000): _*)) } + + test("SPARK-26187 restore the stream-stream inner join query from Spark 2.4") { + val inputStream = MemoryStream[(Int, Long)] + val df = inputStream.toDS() + .select(col("_1").as("value"), col("_2").cast("timestamp").as("timestamp")) + + val leftStream = df.select(col("value").as("leftId"), col("timestamp").as("leftTime")) + + val rightStream = df + // Introduce misses for ease of debugging + .where(col("value") % 2 === 0) + .select(col("value").as("rightId"), col("timestamp").as("rightTime")) + + val query = leftStream + .withWatermark("leftTime", "5 seconds") + .join( + rightStream.withWatermark("rightTime", "5 seconds"), + expr("rightId = leftId AND rightTime >= leftTime AND " + + "rightTime <= leftTime + interval 5 seconds"), + joinType = "inner") + .select(col("leftId"), col("leftTime").cast("int"), + col("rightId"), col("rightTime").cast("int")) + + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-version-2.4.0-streaming-join/").toURI + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the 
checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. + FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + inputStream.addData((1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)) + + testStream(query)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + /* + Note: The checkpoint was generated using the following input in Spark version 2.4.0 + AddData(inputStream, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)), + // batch 1 - global watermark = 0 + // states + // left: (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L) + // right: (2, 2L), (4, 4L) + CheckNewAnswer((2, 2L, 2, 2L), (4, 4L, 4, 4L)), + assertNumStateRows(7, 7), + */ + AddData(inputStream, (6, 6L), (7, 7L), (8, 8L), (9, 9L), (10, 10L)), + // batch 2: same result as above test + CheckNewAnswer((6, 6L, 6, 6L), (8, 8L, 8, 8L), (10, 10L, 10, 10L)), + assertNumStateRows(11, 6), + Execute { query => + // Verify state format = 1 + val f = query.lastExecution.executedPlan.collect { + case f: StreamingSymmetricHashJoinExec => f + } + assert(f.size == 1) + assert(f.head.stateFormatVersion == 1) + } + ) + } } @@ -712,5 +771,167 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with assertNumStateRows(total = 2, updated = 2) ) } + + test("SPARK-26187 self left outer join should not return outer nulls for already matched rows") { + val inputStream = MemoryStream[(Int, Long)] + + val df = inputStream.toDS() + .select(col("_1").as("value"), col("_2").cast("timestamp").as("timestamp")) + + val leftStream = df.select(col("value").as("leftId"), col("timestamp").as("leftTime")) + + val rightStream = df + // Introduce misses for ease of debugging + .where(col("value") % 2 === 0) + .select(col("value").as("rightId"), col("timestamp").as("rightTime")) + + val query = leftStream + .withWatermark("leftTime", "5 seconds") + .join( + rightStream.withWatermark("rightTime", "5 seconds"), + expr("leftId = rightId AND rightTime >= leftTime AND " + + "rightTime <= leftTime + interval 5 seconds"), + joinType = "leftOuter") + .select(col("leftId"), col("leftTime").cast("int"), + col("rightId"), col("rightTime").cast("int")) + + testStream(query)( + AddData(inputStream, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)), + // batch 1 - global watermark = 0 + // states + // left: (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L) + // right: (2, 2L), (4, 4L) + CheckNewAnswer((2, 2L, 2, 2L), (4, 4L, 4, 4L)), + assertNumStateRows(7, 7), + + AddData(inputStream, (6, 6L), (7, 7L), (8, 8L), (9, 9L), (10, 10L)), + // batch 2 - global watermark = 5 + // states + // left: (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L), (6, 6L), (7, 7L), (8, 8L), + // (9, 9L), (10, 10L) + // right: (6, 6L), (8, 8L), (10, 10L) + // states evicted + // left: nothing (it waits for 5 seconds more than watermark due to join condition) + // right: (2, 2L), (4, 4L) + // NOTE: look for evicted rows in right which are not evicted from left - they were + // properly joined in batch 1 + CheckNewAnswer((6, 6L, 6, 6L), (8, 8L, 8, 8L), (10, 10L, 10, 10L)), + assertNumStateRows(13, 8), + + AddData(inputStream, (11, 11L), (12, 12L), (13, 13L), (14, 14L), (15, 15L)), + // batch 3 + // - global watermark = 9 <= min(9, 10) + // states + // left: (4, 4L), (5, 5L), (6, 6L), (7, 7L), (8, 8L), (9, 9L), (10, 10L), (11, 11L), + // (12, 12L), (13, 13L), (14, 14L), (15, 15L) + // right: (10, 10L), (12, 12L), (14, 14L) + // states evicted + // left: (1, 1L), (2, 2L), (3, 3L) + // right: (6, 
6L), (8, 8L) + CheckNewAnswer( + Row(12, 12L, 12, 12L), Row(14, 14L, 14, 14L), + Row(1, 1L, null, null), Row(3, 3L, null, null)), + assertNumStateRows(15, 7) + ) + } + + test("SPARK-26187 self right outer join should not return outer nulls for already matched rows") { + val inputStream = MemoryStream[(Int, Long)] + + val df = inputStream.toDS() + .select(col("_1").as("value"), col("_2").cast("timestamp").as("timestamp")) + + // we're just flipping "left" and "right" from left outer join and apply right outer join + + val leftStream = df + // Introduce misses for ease of debugging + .where(col("value") % 2 === 0) + .select(col("value").as("leftId"), col("timestamp").as("leftTime")) + + val rightStream = df.select(col("value").as("rightId"), col("timestamp").as("rightTime")) + + val query = leftStream + .withWatermark("leftTime", "5 seconds") + .join( + rightStream.withWatermark("rightTime", "5 seconds"), + expr("leftId = rightId AND leftTime >= rightTime AND " + + "leftTime <= rightTime + interval 5 seconds"), + joinType = "rightOuter") + .select(col("leftId"), col("leftTime").cast("int"), + col("rightId"), col("rightTime").cast("int")) + + // we can just flip left and right in the explanation of left outer query test + // to assume the status of right outer query, hence skip explaining here + testStream(query)( + AddData(inputStream, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)), + CheckNewAnswer((2, 2L, 2, 2L), (4, 4L, 4, 4L)), + assertNumStateRows(7, 7), + + AddData(inputStream, (6, 6L), (7, 7L), (8, 8L), (9, 9L), (10, 10L)), + CheckNewAnswer((6, 6L, 6, 6L), (8, 8L, 8, 8L), (10, 10L, 10, 10L)), + assertNumStateRows(13, 8), + + AddData(inputStream, (11, 11L), (12, 12L), (13, 13L), (14, 14L), (15, 15L)), + CheckNewAnswer( + Row(12, 12L, 12, 12L), Row(14, 14L, 14, 14L), + Row(null, null, 1, 1L), Row(null, null, 3, 3L)), + assertNumStateRows(15, 7) + ) + } + + test("SPARK-26187 restore the stream-stream outer join query from Spark 2.4") { + val inputStream = MemoryStream[(Int, Long)] + val df = inputStream.toDS() + .select(col("_1").as("value"), col("_2").cast("timestamp").as("timestamp")) + + val leftStream = df.select(col("value").as("leftId"), col("timestamp").as("leftTime")) + + val rightStream = df + // Introduce misses for ease of debugging + .where(col("value") % 2 === 0) + .select(col("value").as("rightId"), col("timestamp").as("rightTime")) + + val query = leftStream + .withWatermark("leftTime", "5 seconds") + .join( + rightStream.withWatermark("rightTime", "5 seconds"), + expr("rightId = leftId AND rightTime >= leftTime AND " + + "rightTime <= leftTime + interval 5 seconds"), + joinType = "leftOuter") + .select(col("leftId"), col("leftTime").cast("int"), + col("rightId"), col("rightTime").cast("int")) + + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-version-2.4.0-streaming-join/").toURI + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. 
+ FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + inputStream.addData((1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)) + + /* + Note: The checkpoint was generated using the following input in Spark version 2.4.0 + AddData(inputStream, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)), + // batch 1 - global watermark = 0 + // states + // left: (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L) + // right: (2, 2L), (4, 4L) + CheckNewAnswer((2, 2L, 2, 2L), (4, 4L, 4, 4L)), + assertNumStateRows(7, 7), + */ + + // the query should just fail if the checkpoint was created by a Spark version earlier than 3.0 + val e = intercept[StreamingQueryException] { + val writer = query.writeStream.format("console") + .option("checkpointLocation", checkpointDir.getAbsolutePath).start() + inputStream.addData((7, 7L), (8, 8L)) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + } + throw writer.exception.get + } + assert(e.getMessage.toLowerCase(Locale.ROOT) + .contains("the query is using stream-stream outer join with state format version 1")) + } }
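
For readers who want to reproduce the behavior described in the migration note outside of the test suite, here is a minimal sketch of the kind of stream-stream left outer join query affected by the state format change; it mirrors the shape of the queries in StreamingJoinSuite. The rate source, column names, console sink, and the checkpoint path are illustrative assumptions, not part of this patch.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object StreamStreamOuterJoinSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("stream-stream-outer-join-sketch")
      .master("local[2]")
      .getOrCreate()

    // One source split into two sides, similar to how the suite reuses a single MemoryStream.
    val df = spark.readStream.format("rate").option("rowsPerSecond", "10").load()
      .select(col("value"), col("timestamp"))

    val leftStream = df.select(col("value").as("leftId"), col("timestamp").as("leftTime"))
    val rightStream = df
      .where(col("value") % 2 === 0) // introduce misses, as the suite does for ease of debugging
      .select(col("value").as("rightId"), col("timestamp").as("rightTime"))

    // Stream-stream left outer equi-join with a time-range condition; both sides carry
    // watermarks so that state can be evicted and null-extended outer rows can be emitted.
    val joined = leftStream
      .withWatermark("leftTime", "5 seconds")
      .join(
        rightStream.withWatermark("rightTime", "5 seconds"),
        expr("leftId = rightId AND rightTime >= leftTime AND " +
          "rightTime <= leftTime + interval 5 seconds"),
        joinType = "leftOuter")

    // Restarting this query on Spark 3.0 from a checkpoint written by Spark 2.x fails, because
    // the restored query falls back to state format version 1 (the legacy format), which outer
    // joins no longer accept; the checkpoint has to be discarded and the inputs replayed.
    joined.writeStream
      .format("console")
      .option("checkpointLocation", "/tmp/stream-stream-join-checkpoint") // hypothetical path
      .start()
      .awaitTermination()
  }
}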