diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 7f577f015973d..fc6d1e37ae944 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -965,6 +965,16 @@ object SQLConf {
       .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2")
       .createWithDefault(2)
 
+  val STREAMING_JOIN_STATE_FORMAT_VERSION =
+    buildConf("spark.sql.streaming.join.stateFormatVersion")
+      .internal()
+      .doc("State format version used by streaming join operations in a streaming query. " +
+        "State between versions is likely to be incompatible, so the state format version " +
+        "should not be modified once the query has started.")
+      .intConf
+      .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2")
+      .createWithDefault(2)
+
   val UNSUPPORTED_OPERATION_CHECK_ENABLED =
     buildConf("spark.sql.streaming.unsupportedOperationCheck")
       .internal()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 831fc73634869..f40c0340c7e78 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -460,8 +460,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
       case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _)
         if left.isStreaming && right.isStreaming =>
-        new StreamingSymmetricHashJoinExec(
-          leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
+        val stateVersion = conf.getConf(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION)
+
+        new StreamingSymmetricHashJoinExec(leftKeys, rightKeys, joinType, condition,
+          stateVersion, planLater(left), planLater(right)) :: Nil
 
       case Join(left, right, _, _, _) if left.isStreaming && right.isStreaming =>
         throw new AnalysisException(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
index b6fa2e9dc3612..a479adbd39259 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
@@ -23,10 +23,10 @@ import org.json4s.jackson.Serialization
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.RuntimeConfig
 import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StreamingAggregationStateManager}
+import org.apache.spark.sql.execution.streaming.state.join.StreamingJoinStateManager
 import org.apache.spark.sql.internal.SQLConf.{FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, _}
 import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2, SparkDataStream}
-
 /**
  * An ordered collection of offsets, used to track the progress of processing data from one or more
  * [[Source]]s that are present in a streaming query.
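For orientation, a minimal usage sketch of the new option, assuming `clicks` and `impressions` are streaming DataFrames and the checkpoint path is hypothetical. Because the version is recorded per checkpoint through OffsetSeqMetadata (see the OffsetSeq.scala changes below), it only takes effect for newly created checkpoints; an existing checkpoint keeps the version it was started with.

    // Hedged sketch: pin the join state format before the first run of a new query.
    spark.conf.set("spark.sql.streaming.join.stateFormatVersion", "2")

    val joined = clicks.join(impressions, "adId")                    // stream-stream equi-join
    joined.writeStream
      .format("console")
      .option("checkpointLocation", "/tmp/checkpoints/join-demo")    // hypothetical path
      .start()
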
This is similar to simplified, single-instance @@ -91,7 +91,8 @@ object OffsetSeqMetadata extends Logging { private implicit val format = Serialization.formats(NoTypeHints) private val relevantSQLConfs = Seq( SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY, - FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, STREAMING_AGGREGATION_STATE_FORMAT_VERSION) + FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, STREAMING_AGGREGATION_STATE_FORMAT_VERSION, + STREAMING_JOIN_STATE_FORMAT_VERSION) /** * Default values of relevant configurations that are used for backward compatibility. @@ -108,7 +109,9 @@ object OffsetSeqMetadata extends Logging { FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> FlatMapGroupsWithStateExecHelper.legacyVersion.toString, STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> - StreamingAggregationStateManager.legacyVersion.toString + StreamingAggregationStateManager.legacyVersion.toString, + STREAMING_JOIN_STATE_FORMAT_VERSION.key -> + StreamingJoinStateManager.legacyVersion.toString ) def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index 50cf971e4ec3c..8c436e6305868 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -28,6 +28,8 @@ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan} import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._ import org.apache.spark.sql.execution.streaming.state._ +import org.apache.spark.sql.execution.streaming.state.join.StreamingJoinStateManager +import org.apache.spark.sql.execution.streaming.state.join.StreamingJoinStateManager._ import org.apache.spark.sql.internal.SessionState import org.apache.spark.util.{CompletionIterator, SerializableConfiguration} @@ -131,6 +133,7 @@ case class StreamingSymmetricHashJoinExec( stateInfo: Option[StatefulOperatorStateInfo], eventTimeWatermark: Option[Long], stateWatermarkPredicates: JoinStateWatermarkPredicates, + stateFormatVersion: Int, left: SparkPlan, right: SparkPlan) extends SparkPlan with BinaryExecNode with StateStoreWriter { @@ -139,13 +142,14 @@ case class StreamingSymmetricHashJoinExec( rightKeys: Seq[Expression], joinType: JoinType, condition: Option[Expression], + stateFormatVersion: Int, left: SparkPlan, right: SparkPlan) = { this( leftKeys, rightKeys, joinType, JoinConditionSplitPredicates(condition, left, right), stateInfo = None, eventTimeWatermark = None, - stateWatermarkPredicates = JoinStateWatermarkPredicates(), left, right) + stateWatermarkPredicates = JoinStateWatermarkPredicates(), stateFormatVersion, left, right) } private def throwBadJoinTypeException(): Nothing = { @@ -200,7 +204,8 @@ case class StreamingSymmetricHashJoinExec( protected override def doExecute(): RDD[InternalRow] = { val stateStoreCoord = sqlContext.sessionState.streamingQueryManager.stateStoreCoordinator - val stateStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide) + val stateStoreNames = StreamingJoinStateManager.allStateStoreNames(stateFormatVersion, + LeftSide, RightSide) left.execute().stateStoreAwareZipPartitions( right.execute(), 
stateInfo.get, stateStoreNames, stateStoreCoord)(processPartitions) } @@ -223,7 +228,6 @@ case class StreamingSymmetricHashJoinExec( val updateStartTimeNs = System.nanoTime val joinedRow = new JoinedRow - val postJoinFilter = newPredicate(condition.bothSides.getOrElse(Literal(true)), left.output ++ right.output).eval _ val leftSideJoiner = new OneSideHashJoiner( @@ -261,7 +265,6 @@ case class StreamingSymmetricHashJoinExec( val innerOutputIter = CompletionIterator[InternalRow, Iterator[InternalRow]]( (leftOutputIter ++ rightOutputIter), onInnerOutputCompletion) - val outputIter: Iterator[InternalRow] = joinType match { case Inner => innerOutputIter @@ -280,10 +283,17 @@ case class StreamingSymmetricHashJoinExec( postJoinFilter(joinedRow.withLeft(leftKeyValue.value).withRight(rightValue)) } } + val removedRowIter = leftSideJoiner.removeOldState() - val outerOutputIter = removedRowIter - .filterNot(pair => matchesWithRightSideState(pair)) - .map(pair => joinedRow.withLeft(pair.value).withRight(nullRight)) + val outerOutputIter = removedRowIter.filterNot { kvAndMatched => + stateFormatVersion match { + case 1 => matchesWithRightSideState( + new UnsafeRowPair(kvAndMatched.key, kvAndMatched.value)) + case 2 => kvAndMatched.matched.get + case _ => throw new IllegalStateException("Incorrect state format version! " + + s"version $stateFormatVersion") + } + }.map(pair => joinedRow.withLeft(pair.value).withRight(nullRight)) innerOutputIter ++ outerOutputIter case RightOuter => @@ -293,10 +303,17 @@ case class StreamingSymmetricHashJoinExec( postJoinFilter(joinedRow.withLeft(leftValue).withRight(rightKeyValue.value)) } } + val removedRowIter = rightSideJoiner.removeOldState() - val outerOutputIter = removedRowIter - .filterNot(pair => matchesWithLeftSideState(pair)) - .map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value)) + val outerOutputIter = removedRowIter.filterNot { kvAndMatched => + stateFormatVersion match { + case 1 => matchesWithLeftSideState( + new UnsafeRowPair(kvAndMatched.key, kvAndMatched.value)) + case 2 => kvAndMatched.matched.get + case _ => throw new IllegalStateException("Incorrect state format version! " + + s"version $stateFormatVersion") + } + }.map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value)) innerOutputIter ++ outerOutputIter case _ => throwBadJoinTypeException() @@ -394,8 +411,10 @@ case class StreamingSymmetricHashJoinExec( val preJoinFilter = newPredicate(preJoinFilterExpr.getOrElse(Literal(true)), inputAttributes).eval _ - private val joinStateManager = new SymmetricHashJoinStateManager( - joinSide, inputAttributes, joinKeys, stateInfo, storeConf, hadoopConfBcast.value.value) + private val joinStateManager = StreamingJoinStateManager.createStateManager( + joinSide, inputAttributes, joinKeys, stateInfo, storeConf, hadoopConfBcast.value.value, + stateFormatVersion) + private[this] val keyGenerator = UnsafeProjection.create(joinKeys, inputAttributes) private[this] val stateKeyWatermarkPredicateFunc = stateWatermarkPredicate match { @@ -445,16 +464,11 @@ case class StreamingSymmetricHashJoinExec( // the case of inner join). 
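To make the version dispatch above easier to follow, here is a hedged standalone sketch of the same decision; the helper name and its parameters are illustrative, not part of the patch. In format version 1 the state carries no matched flag, so an evicted row has to be re-probed against the other side's state, whereas version 2 simply reads the flag that was stored next to the value.

    // Illustrative helper only; the operator inlines this logic per join side.
    def isEvictedRowMatched(
        stateFormatVersion: Int,
        evicted: KeyToValueAndMatched,
        probeOtherSide: UnsafeRowPair => Boolean): Boolean = stateFormatVersion match {
      case 1 => probeOtherSide(new UnsafeRowPair(evicted.key, evicted.value))  // re-scan other side
      case 2 => evicted.matched.getOrElse(false)                               // read stored flag
      case v => throw new IllegalStateException(s"Unsupported state format version $v")
    }
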
if (preJoinFilter(thisRow)) { val key = keyGenerator(thisRow) - val outputIter = otherSideJoiner.joinStateManager.get(key).map { thatRow => - generateJoinedRow(thisRow, thatRow) - }.filter(postJoinFilter) - val shouldAddToState = // add only if both removal predicates do not match - !stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow) - if (shouldAddToState) { - joinStateManager.append(key, thisRow) - updatedStateRowsCount += 1 - } - outputIter + + val outputIter: Iterator[JoinedRow] = otherSideJoiner.joinStateManager + .getJoinedRows(key, thatRow => generateJoinedRow(thisRow, thatRow), postJoinFilter) + + new AddingProcessedRowToStateCompletionIterator(key, thisRow, outputIter) } else { joinSide match { case LeftSide if joinType == LeftOuter => @@ -467,6 +481,31 @@ case class StreamingSymmetricHashJoinExec( } } + private class AddingProcessedRowToStateCompletionIterator( + key: UnsafeRow, + thisRow: UnsafeRow, + subIter: Iterator[JoinedRow]) + extends CompletionIterator[JoinedRow, Iterator[JoinedRow]](subIter) { + private var iteratorNotEmpty: Boolean = false + + override def hasNext: Boolean = { + val ret = super.hasNext + if (ret && !iteratorNotEmpty) { + iteratorNotEmpty = true + } + ret + } + + override def completion(): Unit = { + val shouldAddToState = // add only if both removal predicates do not match + !stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow) + if (shouldAddToState) { + joinStateManager.append(key, thisRow, matched = iteratorNotEmpty) + updatedStateRowsCount += 1 + } + } + } + /** * Get an iterator over the values stored in this joiner's state manager for the given key. * @@ -486,7 +525,7 @@ case class StreamingSymmetricHashJoinExec( * We do this to avoid requiring either two passes or full materialization when * processing the rows for outer join. */ - def removeOldState(): Iterator[UnsafeRowPair] = { + def removeOldState(): Iterator[KeyToValueAndMatched] = { stateWatermarkPredicate match { case Some(JoinStateKeyWatermarkPredicate(expr)) => joinStateManager.removeByKeyCondition(stateKeyWatermarkPredicateFunc) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala deleted file mode 100644 index 43f22803e7685..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ /dev/null @@ -1,500 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.streaming.state - -import java.util.Locale - -import org.apache.hadoop.conf.Configuration - -import org.apache.spark.TaskContext -import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Literal, SpecificInternalRow, UnsafeProjection, UnsafeRow} -import org.apache.spark.sql.execution.streaming.{StatefulOperatorStateInfo, StreamingSymmetricHashJoinExec} -import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._ -import org.apache.spark.sql.types.{LongType, StructField, StructType} -import org.apache.spark.util.NextIterator - -/** - * Helper class to manage state required by a single side of [[StreamingSymmetricHashJoinExec]]. - * The interface of this class is basically that of a multi-map: - * - Get: Returns an iterator of multiple values for given key - * - Append: Append a new value to the given key - * - Remove Data by predicate: Drop any state using a predicate condition on keys or values - * - * @param joinSide Defines the join side - * @param inputValueAttributes Attributes of the input row which will be stored as value - * @param joinKeys Expressions to generate rows that will be used to key the value rows - * @param stateInfo Information about how to retrieve the correct version of state - * @param storeConf Configuration for the state store. - * @param hadoopConf Hadoop configuration for reading state data from storage - * - * Internally, the key -> multiple values is stored in two [[StateStore]]s. - * - Store 1 ([[KeyToNumValuesStore]]) maintains mapping between key -> number of values - * - Store 2 ([[KeyWithIndexToValueStore]]) maintains mapping between (key, index) -> value - * - Put: update count in KeyToNumValuesStore, - * insert new (key, count) -> value in KeyWithIndexToValueStore - * - Get: read count from KeyToNumValuesStore, - * read each of the n values in KeyWithIndexToValueStore - * - Remove state by predicate on keys: - * scan all keys in KeyToNumValuesStore to find keys that do match the predicate, - * delete from key from KeyToNumValuesStore, delete values in KeyWithIndexToValueStore - * - Remove state by condition on values: - * scan all [(key, index) -> value] in KeyWithIndexToValueStore to find values that match - * the predicate, delete corresponding (key, indexToDelete) from KeyWithIndexToValueStore - * by overwriting with the value of (key, maxIndex), and removing [(key, maxIndex), - * decrement corresponding num values in KeyToNumValuesStore - */ -class SymmetricHashJoinStateManager( - val joinSide: JoinSide, - inputValueAttributes: Seq[Attribute], - joinKeys: Seq[Expression], - stateInfo: Option[StatefulOperatorStateInfo], - storeConf: StateStoreConf, - hadoopConf: Configuration) extends Logging { - - import SymmetricHashJoinStateManager._ - - /* - ===================================================== - Public methods - ===================================================== - */ - - /** Get all the values of a key */ - def get(key: UnsafeRow): Iterator[UnsafeRow] = { - val numValues = keyToNumValues.get(key) - keyWithIndexToValue.getAll(key, numValues).map(_.value) - } - - /** Append a new value to the key */ - def append(key: UnsafeRow, value: UnsafeRow): Unit = { - val numExistingValues = keyToNumValues.get(key) - keyWithIndexToValue.put(key, numExistingValues, value) - keyToNumValues.put(key, numExistingValues + 1) - } - - /** - * Remove using a predicate on keys. 
- * - * This produces an iterator over the (key, value) pairs satisfying condition(key), where the - * underlying store is updated as a side-effect of producing next. - * - * This implies the iterator must be consumed fully without any other operations on this manager - * or the underlying store being interleaved. - */ - def removeByKeyCondition(removalCondition: UnsafeRow => Boolean): Iterator[UnsafeRowPair] = { - new NextIterator[UnsafeRowPair] { - - private val allKeyToNumValues = keyToNumValues.iterator - - private var currentKeyToNumValue: KeyAndNumValues = null - private var currentValues: Iterator[KeyWithIndexAndValue] = null - - private def currentKey = currentKeyToNumValue.key - - private val reusedPair = new UnsafeRowPair() - - private def getAndRemoveValue() = { - val keyWithIndexAndValue = currentValues.next() - keyWithIndexToValue.remove(currentKey, keyWithIndexAndValue.valueIndex) - reusedPair.withRows(currentKey, keyWithIndexAndValue.value) - } - - override def getNext(): UnsafeRowPair = { - // If there are more values for the current key, remove and return the next one. - if (currentValues != null && currentValues.hasNext) { - return getAndRemoveValue() - } - - // If there weren't any values left, try and find the next key that satisfies the removal - // condition and has values. - while (allKeyToNumValues.hasNext) { - currentKeyToNumValue = allKeyToNumValues.next() - if (removalCondition(currentKey)) { - currentValues = keyWithIndexToValue.getAll( - currentKey, currentKeyToNumValue.numValue) - keyToNumValues.remove(currentKey) - - if (currentValues.hasNext) { - return getAndRemoveValue() - } - } - } - - // We only reach here if there were no satisfying keys left, which means we're done. - finished = true - return null - } - - override def close: Unit = {} - } - } - - /** - * Remove using a predicate on values. - * - * At a high level, this produces an iterator over the (key, value) pairs such that value - * satisfies the predicate, where producing an element removes the value from the state store - * and producing all elements with a given key updates it accordingly. - * - * This implies the iterator must be consumed fully without any other operations on this manager - * or the underlying store being interleaved. - */ - def removeByValueCondition(removalCondition: UnsafeRow => Boolean): Iterator[UnsafeRowPair] = { - new NextIterator[UnsafeRowPair] { - - // Reuse this object to avoid creation+GC overhead. - private val reusedPair = new UnsafeRowPair() - - private val allKeyToNumValues = keyToNumValues.iterator - - private var currentKey: UnsafeRow = null - private var numValues: Long = 0L - private var index: Long = 0L - private var valueRemoved: Boolean = false - - // Push the data for the current key to the numValues store, and reset the tracking variables - // to their empty state. - private def updateNumValueForCurrentKey(): Unit = { - if (valueRemoved) { - if (numValues >= 1) { - keyToNumValues.put(currentKey, numValues) - } else { - keyToNumValues.remove(currentKey) - } - } - - currentKey = null - numValues = 0 - index = 0 - valueRemoved = false - } - - // Find the next value satisfying the condition, updating `currentKey` and `numValues` if - // needed. Returns null when no value can be found. - private def findNextValueForIndex(): UnsafeRow = { - // Loop across all values for the current key, and then all other keys, until we find a - // value satisfying the removal condition. 
- def hasMoreValuesForCurrentKey = currentKey != null && index < numValues - def hasMoreKeys = allKeyToNumValues.hasNext - while (hasMoreValuesForCurrentKey || hasMoreKeys) { - if (hasMoreValuesForCurrentKey) { - // First search the values for the current key. - val currentValue = keyWithIndexToValue.get(currentKey, index) - if (removalCondition(currentValue)) { - return currentValue - } else { - index += 1 - } - } else if (hasMoreKeys) { - // If we can't find a value for the current key, cleanup and start looking at the next. - // This will also happen the first time the iterator is called. - updateNumValueForCurrentKey() - - val currentKeyToNumValue = allKeyToNumValues.next() - currentKey = currentKeyToNumValue.key - numValues = currentKeyToNumValue.numValue - } else { - // Should be unreachable, but in any case means a value couldn't be found. - return null - } - } - - // We tried and failed to find the next value. - return null - } - - override def getNext(): UnsafeRowPair = { - val currentValue = findNextValueForIndex() - - // If there's no value, clean up and finish. There aren't any more available. - if (currentValue == null) { - updateNumValueForCurrentKey() - finished = true - return null - } - - // The backing store is arraylike - we as the caller are responsible for filling back in - // any hole. So we swap the last element into the hole and decrement numValues to shorten. - // clean - if (numValues > 1) { - val valueAtMaxIndex = keyWithIndexToValue.get(currentKey, numValues - 1) - keyWithIndexToValue.put(currentKey, index, valueAtMaxIndex) - keyWithIndexToValue.remove(currentKey, numValues - 1) - } else { - keyWithIndexToValue.remove(currentKey, 0) - } - numValues -= 1 - valueRemoved = true - - return reusedPair.withRows(currentKey, currentValue) - } - - override def close: Unit = {} - } - } - - /** Commit all the changes to all the state stores */ - def commit(): Unit = { - keyToNumValues.commit() - keyWithIndexToValue.commit() - } - - /** Abort any changes to the state stores if needed */ - def abortIfNeeded(): Unit = { - keyToNumValues.abortIfNeeded() - keyWithIndexToValue.abortIfNeeded() - } - - /** Get the combined metrics of all the state stores */ - def metrics: StateStoreMetrics = { - val keyToNumValuesMetrics = keyToNumValues.metrics - val keyWithIndexToValueMetrics = keyWithIndexToValue.metrics - def newDesc(desc: String): String = s"${joinSide.toString.toUpperCase(Locale.ROOT)}: $desc" - - StateStoreMetrics( - keyWithIndexToValueMetrics.numKeys, // represent each buffered row only once - keyToNumValuesMetrics.memoryUsedBytes + keyWithIndexToValueMetrics.memoryUsedBytes, - keyWithIndexToValueMetrics.customMetrics.map { - case (s @ StateStoreCustomSumMetric(_, desc), value) => - s.copy(desc = newDesc(desc)) -> value - case (s @ StateStoreCustomSizeMetric(_, desc), value) => - s.copy(desc = newDesc(desc)) -> value - case (s @ StateStoreCustomTimingMetric(_, desc), value) => - s.copy(desc = newDesc(desc)) -> value - case (s, _) => - throw new IllegalArgumentException( - s"Unknown state store custom metric is found at metrics: $s") - } - ) - } - - /* - ===================================================== - Private methods and inner classes - ===================================================== - */ - - private val keySchema = StructType( - joinKeys.zipWithIndex.map { case (k, i) => StructField(s"field$i", k.dataType, k.nullable) }) - private val keyAttributes = keySchema.toAttributes - private val keyToNumValues = new KeyToNumValuesStore() - private val 
keyWithIndexToValue = new KeyWithIndexToValueStore() - - // Clean up any state store resources if necessary at the end of the task - Option(TaskContext.get()).foreach { _.addTaskCompletionListener[Unit] { _ => abortIfNeeded() } } - - /** Helper trait for invoking common functionalities of a state store. */ - private abstract class StateStoreHandler(stateStoreType: StateStoreType) extends Logging { - - /** StateStore that the subclasses of this class is going to operate on */ - protected def stateStore: StateStore - - def commit(): Unit = { - stateStore.commit() - logDebug("Committed, metrics = " + stateStore.metrics) - } - - def abortIfNeeded(): Unit = { - if (!stateStore.hasCommitted) { - logInfo(s"Aborted store ${stateStore.id}") - stateStore.abort() - } - } - - def metrics: StateStoreMetrics = stateStore.metrics - - /** Get the StateStore with the given schema */ - protected def getStateStore(keySchema: StructType, valueSchema: StructType): StateStore = { - val storeProviderId = StateStoreProviderId( - stateInfo.get, TaskContext.getPartitionId(), getStateStoreName(joinSide, stateStoreType)) - val store = StateStore.get( - storeProviderId, keySchema, valueSchema, None, - stateInfo.get.storeVersion, storeConf, hadoopConf) - logInfo(s"Loaded store ${store.id}") - store - } - } - - /** - * Helper class for representing data returned by [[KeyWithIndexToValueStore]]. - * Designed for object reuse. - */ - private case class KeyAndNumValues(var key: UnsafeRow = null, var numValue: Long = 0) { - def withNew(newKey: UnsafeRow, newNumValues: Long): this.type = { - this.key = newKey - this.numValue = newNumValues - this - } - } - - - /** A wrapper around a [[StateStore]] that stores [key -> number of values]. */ - private class KeyToNumValuesStore extends StateStoreHandler(KeyToNumValuesType) { - private val longValueSchema = new StructType().add("value", "long") - private val longToUnsafeRow = UnsafeProjection.create(longValueSchema) - private val valueRow = longToUnsafeRow(new SpecificInternalRow(longValueSchema)) - protected val stateStore: StateStore = getStateStore(keySchema, longValueSchema) - - /** Get the number of values the key has */ - def get(key: UnsafeRow): Long = { - val longValueRow = stateStore.get(key) - if (longValueRow != null) longValueRow.getLong(0) else 0L - } - - /** Set the number of values the key has */ - def put(key: UnsafeRow, numValues: Long): Unit = { - require(numValues > 0) - valueRow.setLong(0, numValues) - stateStore.put(key, valueRow) - } - - def remove(key: UnsafeRow): Unit = { - stateStore.remove(key) - } - - def iterator: Iterator[KeyAndNumValues] = { - val keyAndNumValues = new KeyAndNumValues() - stateStore.getRange(None, None).map { case pair => - keyAndNumValues.withNew(pair.key, pair.value.getLong(0)) - } - } - } - - /** - * Helper class for representing data returned by [[KeyWithIndexToValueStore]]. - * Designed for object reuse. - */ - private case class KeyWithIndexAndValue( - var key: UnsafeRow = null, var valueIndex: Long = -1, var value: UnsafeRow = null) { - def withNew(newKey: UnsafeRow, newIndex: Long, newValue: UnsafeRow): this.type = { - this.key = newKey - this.valueIndex = newIndex - this.value = newValue - this - } - } - - /** A wrapper around a [[StateStore]] that stores [(key, index) -> value]. 
*/ - private class KeyWithIndexToValueStore extends StateStoreHandler(KeyWithIndexToValueType) { - private val keyWithIndexExprs = keyAttributes :+ Literal(1L) - private val keyWithIndexSchema = keySchema.add("index", LongType) - private val indexOrdinalInKeyWithIndexRow = keyAttributes.size - - // Projection to generate (key + index) row from key row - private val keyWithIndexRowGenerator = UnsafeProjection.create(keyWithIndexExprs, keyAttributes) - - // Projection to generate key row from (key + index) row - private val keyRowGenerator = UnsafeProjection.create( - keyAttributes, keyAttributes :+ AttributeReference("index", LongType)()) - - protected val stateStore = getStateStore(keyWithIndexSchema, inputValueAttributes.toStructType) - - def get(key: UnsafeRow, valueIndex: Long): UnsafeRow = { - stateStore.get(keyWithIndexRow(key, valueIndex)) - } - - /** - * Get all values and indices for the provided key. - * Should not return null. - */ - def getAll(key: UnsafeRow, numValues: Long): Iterator[KeyWithIndexAndValue] = { - val keyWithIndexAndValue = new KeyWithIndexAndValue() - var index = 0 - new NextIterator[KeyWithIndexAndValue] { - override protected def getNext(): KeyWithIndexAndValue = { - if (index >= numValues) { - finished = true - null - } else { - val keyWithIndex = keyWithIndexRow(key, index) - val value = stateStore.get(keyWithIndex) - keyWithIndexAndValue.withNew(key, index, value) - index += 1 - keyWithIndexAndValue - } - } - - override protected def close(): Unit = {} - } - } - - /** Put new value for key at the given index */ - def put(key: UnsafeRow, valueIndex: Long, value: UnsafeRow): Unit = { - val keyWithIndex = keyWithIndexRow(key, valueIndex) - stateStore.put(keyWithIndex, value) - } - - /** - * Remove key and value at given index. Note that this will create a hole in - * (key, index) and it is upto the caller to deal with it. - */ - def remove(key: UnsafeRow, valueIndex: Long): Unit = { - stateStore.remove(keyWithIndexRow(key, valueIndex)) - } - - /** Remove all values (i.e. all the indices) for the given key. 
*/ - def removeAllValues(key: UnsafeRow, numValues: Long): Unit = { - var index = 0 - while (index < numValues) { - stateStore.remove(keyWithIndexRow(key, index)) - index += 1 - } - } - - def iterator: Iterator[KeyWithIndexAndValue] = { - val keyWithIndexAndValue = new KeyWithIndexAndValue() - stateStore.getRange(None, None).map { pair => - keyWithIndexAndValue.withNew( - keyRowGenerator(pair.key), pair.key.getLong(indexOrdinalInKeyWithIndexRow), pair.value) - keyWithIndexAndValue - } - } - - /** Generated a row using the key and index */ - private def keyWithIndexRow(key: UnsafeRow, valueIndex: Long): UnsafeRow = { - val row = keyWithIndexRowGenerator(key) - row.setLong(indexOrdinalInKeyWithIndexRow, valueIndex) - row - } - } -} - -object SymmetricHashJoinStateManager { - - def allStateStoreNames(joinSides: JoinSide*): Seq[String] = { - val allStateStoreTypes: Seq[StateStoreType] = Seq(KeyToNumValuesType, KeyWithIndexToValueType) - for (joinSide <- joinSides; stateStoreType <- allStateStoreTypes) yield { - getStateStoreName(joinSide, stateStoreType) - } - } - - private sealed trait StateStoreType - - private case object KeyToNumValuesType extends StateStoreType { - override def toString(): String = "keyToNumValues" - } - - private case object KeyWithIndexToValueType extends StateStoreType { - override def toString(): String = "keyWithIndexToValue" - } - - private def getStateStoreName(joinSide: JoinSide, storeType: StateStoreType): String = { - s"$joinSide-$storeType" - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/join/StateStoreHandlers.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/join/StateStoreHandlers.scala new file mode 100644 index 0000000000000..67de2d84d6542 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/join/StateStoreHandlers.scala @@ -0,0 +1,468 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state.join + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo +import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.JoinSide +import org.apache.spark.sql.execution.streaming.state._ +import org.apache.spark.sql.execution.streaming.state.join.StreamingJoinStateManager._ +import org.apache.spark.sql.types.{BooleanType, LongType, StructType} +import org.apache.spark.util.NextIterator + +/** Helper trait for invoking common functionalities of a state store. 
*/ +private[sql] abstract class StateStoreHandler( + stateStoreType: StateStoreType, + joinSide: JoinSide, + stateInfo: Option[StatefulOperatorStateInfo], + storeConf: StateStoreConf, + hadoopConf: Configuration) extends Logging { + /** StateStore that the subclasses of this class is going to operate on */ + protected def stateStore: StateStore + + def commit(): Unit = { + stateStore.commit() + logDebug("Committed, metrics = " + stateStore.metrics) + } + + def abortIfNeeded(): Unit = { + if (!stateStore.hasCommitted) { + logInfo(s"Aborted store ${stateStore.id}") + stateStore.abort() + } + } + + def metrics: StateStoreMetrics = stateStore.metrics + + /** Get the StateStore with the given schema */ + protected def getStateStore(keySchema: StructType, valueSchema: StructType): StateStore = { + val storeProviderId = StateStoreProviderId( + stateInfo.get, TaskContext.getPartitionId(), getStateStoreName(joinSide, stateStoreType)) + val store = StateStore.get( + storeProviderId, keySchema, valueSchema, None, + stateInfo.get.storeVersion, storeConf, hadoopConf) + logInfo(s"Loaded store ${store.id}") + store + } +} + +/** + * Helper class for representing data returned by [[KeyWithIndexToValueStore]]. + * Designed for object reuse. + */ +private[sql] case class KeyAndNumValues(var key: UnsafeRow = null, var numValue: Long = 0) { + def withNew(newKey: UnsafeRow, newNumValues: Long): this.type = { + this.key = newKey + this.numValue = newNumValues + this + } +} + +/** A wrapper around a [[StateStore]] that stores [key -> number of values]. */ +private[sql] class KeyToNumValuesStore( + storeType: StateStoreType, + joinSide: JoinSide, + stateInfo: Option[StatefulOperatorStateInfo], + storeConf: StateStoreConf, + hadoopConf: Configuration, + keyAttributes: Seq[Attribute]) + extends StateStoreHandler( + storeType, + joinSide, + stateInfo, + storeConf, + hadoopConf) { + private val keySchema = keyAttributes.toStructType + private val longValueSchema = new StructType().add("value", "long") + private val longToUnsafeRow = UnsafeProjection.create(longValueSchema) + private val valueRow = longToUnsafeRow(new SpecificInternalRow(longValueSchema)) + protected val stateStore: StateStore = getStateStore(keySchema, longValueSchema) + + /** Get the number of values the key has */ + def get(key: UnsafeRow): Long = { + val longValueRow = stateStore.get(key) + if (longValueRow != null) longValueRow.getLong(0) else 0L + } + + /** Set the number of values the key has */ + def put(key: UnsafeRow, numValues: Long): Unit = { + require(numValues > 0) + valueRow.setLong(0, numValues) + stateStore.put(key, valueRow) + } + + def remove(key: UnsafeRow): Unit = { + stateStore.remove(key) + } + + def iterator: Iterator[KeyAndNumValues] = { + val keyAndNumValues = new KeyAndNumValues() + stateStore.getRange(None, None).map { case pair => + keyAndNumValues.withNew(pair.key, pair.value.getLong(0)) + } + } +} + +private[sql] abstract class KeyWithIndexToValueStore[T]( + storeType: StateStoreType, + joinSide: JoinSide, + stateInfo: Option[StatefulOperatorStateInfo], + storeConf: StateStoreConf, + hadoopConf: Configuration, + keyAttributes: Seq[Attribute], + valueSchema: StructType) + extends StateStoreHandler( + storeType, + joinSide, + stateInfo, + storeConf, + hadoopConf) { + + /** + * Helper class for representing data returned by [[KeyWithIndexToValueStore]]. + * Designed for object reuse. 
+ */ + case class KeyWithIndexAndValue( + var key: UnsafeRow = null, + var valueIndex: Long = -1, + var value: T = null.asInstanceOf[T]) { + def withNew(newKey: UnsafeRow, newIndex: Long, newValue: T): this.type = { + this.key = newKey + this.valueIndex = newIndex + this.value = newValue + this + } + } + + private val keyWithIndexExprs = keyAttributes :+ Literal(1L) + private val keySchema = keyAttributes.toStructType + private val keyWithIndexSchema = keySchema.add("index", LongType) + private val indexOrdinalInKeyWithIndexRow = keyAttributes.size + + private val keyWithIndexRowGenerator = UnsafeProjection.create(keyWithIndexExprs, keyAttributes) + + // Projection to generate key row from (key + index) row + private val keyRowGenerator = UnsafeProjection.create( + keyAttributes, keyAttributes :+ AttributeReference("index", LongType)()) + + protected val stateStore = getStateStore(keyWithIndexSchema, valueSchema) + + def get(key: UnsafeRow, valueIndex: Long): T = { + convertValue(stateStore.get(keyWithIndexRow(key, valueIndex))) + } + + /** + * Get all values and indices for the provided key. + * Should not return null. + */ + def getAll(key: UnsafeRow, numValues: Long): Iterator[KeyWithIndexAndValue] = { + val keyWithIndexAndValue = new KeyWithIndexAndValue() + var index = 0 + new NextIterator[KeyWithIndexAndValue] { + override protected def getNext(): KeyWithIndexAndValue = { + if (index >= numValues) { + finished = true + null + } else { + val keyWithIndex = keyWithIndexRow(key, index) + val value = stateStore.get(keyWithIndex) + keyWithIndexAndValue.withNew(key, index, convertValue(value)) + index += 1 + keyWithIndexAndValue + } + } + + override protected def close(): Unit = {} + } + } + + /** Put new value for key at the given index */ + def put(key: UnsafeRow, valueIndex: Long, value: T): Unit = { + val keyWithIndex = keyWithIndexRow(key, valueIndex) + val row = convertToValueRow(value) + if (row != null) { + stateStore.put(keyWithIndex, row) + } + } + + /** + * Remove key and value at given index. Note that this will create a hole in + * (key, index) and it is upto the caller to deal with it. + */ + def remove(key: UnsafeRow, valueIndex: Long): Unit = { + stateStore.remove(keyWithIndexRow(key, valueIndex)) + } + + /** Remove all values (i.e. all the indices) for the given key. */ + def removeAllValues(key: UnsafeRow, numValues: Long): Unit = { + var index = 0 + while (index < numValues) { + stateStore.remove(keyWithIndexRow(key, index)) + index += 1 + } + } + + def iterator: Iterator[KeyWithIndexAndValue] = { + val keyWithIndexAndValue = new KeyWithIndexAndValue() + stateStore.getRange(None, None).map { pair => + keyWithIndexAndValue.withNew(keyRowGenerator(pair.key), + pair.key.getLong(indexOrdinalInKeyWithIndexRow), convertValue(pair.value)) + keyWithIndexAndValue + } + } + + /** Generated a row using the key and index */ + protected def keyWithIndexRow(key: UnsafeRow, valueIndex: Long): UnsafeRow = { + val row = keyWithIndexRowGenerator(key) + row.setLong(indexOrdinalInKeyWithIndexRow, valueIndex) + row + } + + protected def convertValue(value: UnsafeRow): T + protected def convertToValueRow(value: T): UnsafeRow +} + +/** A wrapper around a [[StateStore]] that stores [(key, index) -> value]. 
*/ +private[sql] class KeyWithIndexToRowValueStore( + storeType: StateStoreType, + joinSide: JoinSide, + stateInfo: Option[StatefulOperatorStateInfo], + storeConf: StateStoreConf, + hadoopConf: Configuration, + keyAttributes: Seq[Attribute], + valueAttributes: Seq[Attribute]) + extends KeyWithIndexToValueStore[UnsafeRow]( + storeType, + joinSide, + stateInfo, + storeConf, + hadoopConf, + keyAttributes, + valueAttributes.toStructType) { + + override protected def convertValue(value: UnsafeRow): UnsafeRow = value + + override protected def convertToValueRow(value: UnsafeRow): UnsafeRow = value +} + +/** A wrapper around a [[StateStore]] that stores [(key, index) -> (value, matched)]. */ +private[sql] class KeyWithIndexToRowAndMatchedStore( + storeType: StateStoreType, + joinSide: JoinSide, + stateInfo: Option[StatefulOperatorStateInfo], + storeConf: StateStoreConf, + hadoopConf: Configuration, + keyAttributes: Seq[Attribute], + valueAttributes: Seq[Attribute]) + extends KeyWithIndexToValueStore[(UnsafeRow, Boolean)]( + storeType, + joinSide, + stateInfo, + storeConf, + hadoopConf, + keyAttributes, + valueAttributes.toStructType.add("matched", BooleanType)) { + + private val valueWithMatchedExprs = valueAttributes :+ Literal(true) + private val indexOrdinalInValueWithMatchedRow = valueAttributes.size + + private val valueWithMatchedRowGenerator = UnsafeProjection.create(valueWithMatchedExprs, + valueAttributes) + + // Projection to generate key row from (value + matched) row + private val valueRowGenerator = UnsafeProjection.create( + valueAttributes, valueAttributes :+ AttributeReference("matched", BooleanType)()) + + override protected def convertValue(value: UnsafeRow): (UnsafeRow, Boolean) = { + if (value != null) { + (valueRowGenerator(value), value.getBoolean(indexOrdinalInValueWithMatchedRow)) + } else null + } + + override protected def convertToValueRow(valueAndMatched: (UnsafeRow, Boolean)): UnsafeRow = { + val (value, matched) = valueAndMatched + valueWithMatchedRow(value, matched) + } + + /** Generated a row using the value and matched */ + protected def valueWithMatchedRow(key: UnsafeRow, matched: Boolean): UnsafeRow = { + val row = valueWithMatchedRowGenerator(key) + row.setBoolean(indexOrdinalInValueWithMatchedRow, matched) + row + } +} + +object StateStoreHandlers { + def removeByKeyCondition[T]( + keyToNumValues: KeyToNumValuesStore, + keyWithIndexToValue: KeyWithIndexToValueStore[T], + convertFn: T => (UnsafeRow, Option[Boolean]), + removalCondition: UnsafeRow => Boolean): Iterator[KeyToValueAndMatched] = { + new NextIterator[KeyToValueAndMatched] { + + private val allKeyToNumValues = keyToNumValues.iterator + + private var currentKeyToNumValue: KeyAndNumValues = null + private var currentValues: Iterator[keyWithIndexToValue.KeyWithIndexAndValue] = null + + private def currentKey = currentKeyToNumValue.key + + private val reusedTuple = new KeyToValueAndMatched() + + private def getAndRemoveValue(): KeyToValueAndMatched = { + val keyWithIndexAndValue = currentValues.next() + keyWithIndexToValue.remove(currentKey, keyWithIndexAndValue.valueIndex) + val (row, matched) = convertFn(keyWithIndexAndValue.value) + reusedTuple.withNew(currentKey, row, matched) + } + + override def getNext(): KeyToValueAndMatched = { + // If there are more values for the current key, remove and return the next one. 
+ if (currentValues != null && currentValues.hasNext) { + return getAndRemoveValue() + } + + // If there weren't any values left, try and find the next key that satisfies the removal + // condition and has values. + while (allKeyToNumValues.hasNext) { + currentKeyToNumValue = allKeyToNumValues.next() + if (removalCondition(currentKey)) { + currentValues = keyWithIndexToValue.getAll( + currentKey, currentKeyToNumValue.numValue) + keyToNumValues.remove(currentKey) + + if (currentValues.hasNext) { + return getAndRemoveValue() + } + } + } + + // We only reach here if there were no satisfying keys left, which means we're done. + finished = true + return null + } + + override def close: Unit = {} + } + } + + def removeByValueCondition[T]( + keyToNumValues: KeyToNumValuesStore, + keyWithIndexToValue: KeyWithIndexToValueStore[T], + convertFn: T => (UnsafeRow, Option[Boolean]), + removalCondition: UnsafeRow => Boolean): Iterator[KeyToValueAndMatched] = { + new NextIterator[KeyToValueAndMatched] { + + // Reuse this object to avoid creation+GC overhead. + private val reusedTuple = new KeyToValueAndMatched() + + private val allKeyToNumValues = keyToNumValues.iterator + + private var currentKey: UnsafeRow = null + private var numValues: Long = 0L + private var index: Long = 0L + private var valueRemoved: Boolean = false + + // Push the data for the current key to the numValues store, and reset the tracking variables + // to their empty state. + private def updateNumValueForCurrentKey(): Unit = { + if (valueRemoved) { + if (numValues >= 1) { + keyToNumValues.put(currentKey, numValues) + } else { + keyToNumValues.remove(currentKey) + } + } + + currentKey = null + numValues = 0 + index = 0 + valueRemoved = false + } + + // Find the next value satisfying the condition, updating `currentKey` and `numValues` if + // needed. Returns null when no value can be found. + private def findNextValueForIndex(): (UnsafeRow, Option[Boolean]) = { + // Loop across all values for the current key, and then all other keys, until we find a + // value satisfying the removal condition. + def hasMoreValuesForCurrentKey = currentKey != null && index < numValues + def hasMoreKeys = allKeyToNumValues.hasNext + while (hasMoreValuesForCurrentKey || hasMoreKeys) { + if (hasMoreValuesForCurrentKey) { + // First search the values for the current key. + val currentValue = keyWithIndexToValue.get(currentKey, index) + val (row, matched) = convertFn(currentValue) + if (removalCondition(row)) { + return (row, matched) + } else { + index += 1 + } + } else if (hasMoreKeys) { + // If we can't find a value for the current key, cleanup and start looking at the next. + // This will also happen the first time the iterator is called. + updateNumValueForCurrentKey() + + val currentKeyToNumValue = allKeyToNumValues.next() + currentKey = currentKeyToNumValue.key + numValues = currentKeyToNumValue.numValue + } else { + // Should be unreachable, but in any case means a value couldn't be found. + return null + } + } + + // We tried and failed to find the next value. + return null + } + + override def getNext(): KeyToValueAndMatched = { + val currentValue = findNextValueForIndex() + + // If there's no value, clean up and finish. There aren't any more available. + if (currentValue == null) { + updateNumValueForCurrentKey() + finished = true + return null + } + + // The backing store is arraylike - we as the caller are responsible for filling back in + // any hole. So we swap the last element into the hole and decrement numValues to shorten. 
+ // clean + if (numValues > 1) { + val valueAtMaxIndex = keyWithIndexToValue.get(currentKey, numValues - 1) + keyWithIndexToValue.put(currentKey, index, valueAtMaxIndex) + keyWithIndexToValue.remove(currentKey, numValues - 1) + } else { + keyWithIndexToValue.remove(currentKey, 0) + } + numValues -= 1 + valueRemoved = true + + val (value, matched) = currentValue + return reusedTuple.withNew(currentKey, value, matched) + } + + override def close: Unit = {} + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/join/StreamingJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/join/StreamingJoinStateManager.scala new file mode 100644 index 0000000000000..576aeb665d341 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/join/StreamingJoinStateManager.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state.join + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, UnsafeRow} +import org.apache.spark.sql.execution.streaming.{StatefulOperatorStateInfo, StreamingSymmetricHashJoinExec} +import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.JoinSide +import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreMetrics} + +/** + * Helper class to manage state required by a single side of [[StreamingSymmetricHashJoinExec]]. + * The interface of this class is basically that of a multi-map: + * - Get: Returns an iterator of multiple values for given key + * - Append: Append a new value to the given key + * - Remove Data by predicate: Drop any state using a predicate condition on keys or values + * + * @param joinSide Defines the join side + * @param inputValueAttributes Attributes of the input row which will be stored as value + * @param joinKeys Expressions to generate rows that will be used to key the value rows + * @param stateInfo Information about how to retrieve the correct version of state + * @param storeConf Configuration for the state store. + * @param hadoopConf Hadoop configuration for reading state data from storage + */ +trait StreamingJoinStateManager extends Serializable { + import StreamingJoinStateManager._ + + /** Get all the values of a key */ + def get(key: UnsafeRow): Iterator[UnsafeRow] + + /** + * Get all the matched values for given join condition, with marking matched. + * This method is designed to mark joined rows properly without exposing internal index of row. 
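A short, hedged sketch of the multi-map contract described above, assuming a `manager` obtained from `StreamingJoinStateManager.createStateManager`, prebuilt `key`/`rowA`/`rowB` UnsafeRows, and a `watermarkPredicate` on keys. The removal iterators mutate state as they are consumed, so they must be drained before any other operation on the manager.

    manager.append(key, rowA, matched = false)
    manager.append(key, rowB, matched = true)

    val buffered = manager.get(key)                      // iterator over rowA, rowB

    // Eviction is lazy: every next() removes state, so consume the iterator fully.
    manager.removeByKeyCondition(watermarkPredicate).foreach { kv =>
      // kv.matched is Some(flag) for format version 2 and None for version 1
      println(s"evicted key=${kv.key} matched=${kv.matched}")
    }
    manager.commit()
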
+ */ + def getJoinedRows( + key: UnsafeRow, + generateJoinedRow: InternalRow => JoinedRow, + predicate: JoinedRow => Boolean): Iterator[JoinedRow] + + /** Append a new value to the key, with marking matched */ + def append(key: UnsafeRow, value: UnsafeRow, matched: Boolean): Unit + + /** + * Remove using a predicate on keys. + * + * This produces an iterator over the (key, value, matched) tuples satisfying condition(key), + * where the underlying store is updated as a side-effect of producing next. + * + * This implies the iterator must be consumed fully without any other operations on this manager + * or the underlying store being interleaved. + */ + def removeByKeyCondition(removalCondition: UnsafeRow => Boolean): Iterator[KeyToValueAndMatched] + + /** + * Remove using a predicate on values. + * + * At a high level, this produces an iterator over the (key, (value, matched)) pairs such that + * value satisfies the predicate, where producing an element removes the value from the state + * store and producing all elements with a given key updates it accordingly. + * + * This implies the iterator must be consumed fully without any other operations on this manager + * or the underlying store being interleaved. + */ + def removeByValueCondition(removalCondition: UnsafeRow => Boolean): Iterator[KeyToValueAndMatched] + + /** Commit all the changes to all the state stores */ + def commit(): Unit + + /** Abort any changes to the state stores if needed */ + def abortIfNeeded(): Unit + + /** Get the combined metrics of all the state stores */ + def metrics: StateStoreMetrics +} + +object StreamingJoinStateManager { + val supportedVersions = Seq(1, 2) + val legacyVersion = 1 + + def createStateManager( + joinSide: JoinSide, + inputValueAttributes: Seq[Attribute], + joinKeys: Seq[Expression], + stateInfo: Option[StatefulOperatorStateInfo], + storeConf: StateStoreConf, + hadoopConf: Configuration, + stateFormatVersion: Int): StreamingJoinStateManager = stateFormatVersion match { + case 1 => new StreamingJoinStateManagerImplV1(joinSide, inputValueAttributes, joinKeys, + stateInfo, storeConf, hadoopConf) + case 2 => new StreamingJoinStateManagerImplV2(joinSide, inputValueAttributes, joinKeys, + stateInfo, storeConf, hadoopConf) + case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid") + } + + def allStateStoreNames(stateFormatVersion: Int, joinSides: JoinSide*): Seq[String] = { + val allStateStoreTypes: Seq[StateStoreType] = stateFormatVersion match { + case 1 => StreamingJoinStateManagerImplV1.allStateStoreTypes + case 2 => StreamingJoinStateManagerImplV2.allStateStoreTypes + case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid") + } + + for (joinSide <- joinSides; stateStoreType <- allStateStoreTypes) yield { + getStateStoreName(joinSide, stateStoreType) + } + } + + def getStateStoreName(joinSide: JoinSide, storeType: StateStoreType): String = { + s"$joinSide-$storeType" + } + + /** + * Helper class for representing data key to (value, matched). + * Designed for object reuse. 
+ */ + case class KeyToValueAndMatched( + var key: UnsafeRow = null, + var value: UnsafeRow = null, + var matched: Option[Boolean] = None) { + def withNew(newKey: UnsafeRow, newValue: UnsafeRow, newMatched: Option[Boolean]): this.type = { + this.key = newKey + this.value = newValue + this.matched = newMatched + this + } + } + + trait StateStoreType +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/join/StreamingJoinStateManagerImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/join/StreamingJoinStateManagerImpl.scala new file mode 100644 index 0000000000000..eac3d647a1218 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/join/StreamingJoinStateManagerImpl.scala @@ -0,0 +1,295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state.join + +import java.util.Locale + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, UnsafeRow} +import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo +import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.JoinSide +import org.apache.spark.sql.execution.streaming.state.{StateStore, _} +import org.apache.spark.sql.execution.streaming.state.join.StreamingJoinStateManager.{KeyToValueAndMatched, StateStoreType} +import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.util.NextIterator + +private[sql] abstract class BaseStreamingJoinStateManagerImpl( + protected val joinSide: JoinSide, + protected val inputValueAttributes: Seq[Attribute], + protected val joinKeys: Seq[Expression], + protected val stateInfo: Option[StatefulOperatorStateInfo], + protected val storeConf: StateStoreConf, + protected val hadoopConf: Configuration) + extends StreamingJoinStateManager { + + protected val keySchema = StructType( + joinKeys.zipWithIndex.map { case (k, i) => StructField(s"field$i", k.dataType, k.nullable) }) + protected val keyAttributes = keySchema.toAttributes + + // Clean up any state store resources if necessary at the end of the task + Option(TaskContext.get()).foreach { _.addTaskCompletionListener[Unit] { _ => abortIfNeeded() } } +} + +/** + * Please refer [[StreamingJoinStateManager]] on documentation which is not related to internal. + * + * Internally, the key -> multiple values is stored in two [[StateStore]]s. 
+ * - Store 1 ([[KeyToNumValuesStore]]) maintains mapping between key -> number of values + * - Store 2 ([[KeyWithIndexToRowValueStore]]) maintains mapping between (key, index) -> value + * - Put: update count in KeyToNumValuesStore, + * insert new (key, count) -> value in KeyWithIndexToValueStore + * - Get: read count from KeyToNumValuesStore, + * read each of the n values in KeyWithIndexToValueStore + * - Remove state by predicate on keys: + * scan all keys in KeyToNumValuesStore to find keys that match the predicate, + * delete the key from KeyToNumValuesStore, delete values in KeyWithIndexToValueStore + * - Remove state by condition on values: + * scan all [(key, index) -> value] in KeyWithIndexToValueStore to find values that match + * the predicate, delete the corresponding (key, indexToDelete) from KeyWithIndexToValueStore + * by overwriting it with the value of (key, maxIndex), + * removing (key, maxIndex), and decrementing the corresponding num values in + * KeyToNumValuesStore + */
+private[sql] class StreamingJoinStateManagerImplV1( + joinSide: JoinSide, + inputValueAttributes: Seq[Attribute], + joinKeys: Seq[Expression], + stateInfo: Option[StatefulOperatorStateInfo], + storeConf: StateStoreConf, + hadoopConf: Configuration) + extends BaseStreamingJoinStateManagerImpl( + joinSide: JoinSide, + inputValueAttributes: Seq[Attribute], + joinKeys: Seq[Expression], + stateInfo: Option[StatefulOperatorStateInfo], + storeConf: StateStoreConf, + hadoopConf: Configuration) { + + import StreamingJoinStateManagerImplV1._ + + private val keyToNumValues = new KeyToNumValuesStore(KeyToNumValuesType, joinSide, stateInfo, + storeConf, hadoopConf, keyAttributes) + private val keyWithIndexToValue = new KeyWithIndexToRowValueStore(KeyWithIndexToRowValueType, + joinSide, stateInfo, storeConf, hadoopConf, keyAttributes, inputValueAttributes) +
+ override def get(key: UnsafeRow): Iterator[UnsafeRow] = { + val numValues = keyToNumValues.get(key) + keyWithIndexToValue.getAll(key, numValues).map(_.value) + } + + override def getJoinedRows( + key: UnsafeRow, + generateJoinedRow: InternalRow => JoinedRow, + predicate: JoinedRow => Boolean): Iterator[JoinedRow] = { + val numValues = keyToNumValues.get(key) + keyWithIndexToValue.getAll(key, numValues).map { keyIdxToValue => + generateJoinedRow(keyIdxToValue.value) + }.filter(predicate) + } + + override def append(key: UnsafeRow, value: UnsafeRow, matched: Boolean): Unit = { + // V1 doesn't leverage 'matched' information + val numExistingValues = keyToNumValues.get(key) + keyWithIndexToValue.put(key, numExistingValues, value) + keyToNumValues.put(key, numExistingValues + 1) + } +
+ override def removeByKeyCondition( + removalCondition: UnsafeRow => Boolean): Iterator[KeyToValueAndMatched] = { + StateStoreHandlers.removeByKeyCondition(keyToNumValues, keyWithIndexToValue, + (row: UnsafeRow) => (row, None), removalCondition) + } + + override def removeByValueCondition( + removalCondition: UnsafeRow => Boolean): Iterator[KeyToValueAndMatched] = { + StateStoreHandlers.removeByValueCondition(keyToNumValues, keyWithIndexToValue, + (row: UnsafeRow) => (row, None), removalCondition) + } + + override def commit(): Unit = { + keyToNumValues.commit() + keyWithIndexToValue.commit() + } + + override def abortIfNeeded(): Unit = { + keyToNumValues.abortIfNeeded() + keyWithIndexToValue.abortIfNeeded() + } +
+ override def metrics: StateStoreMetrics = { + val keyToNumValuesMetrics = keyToNumValues.metrics + val keyWithIndexToValueMetrics = keyWithIndexToValue.metrics + def newDesc(desc: String): String = s"${joinSide.toString.toUpperCase(Locale.ROOT)}: $desc" + + StateStoreMetrics( + keyWithIndexToValueMetrics.numKeys, // represent each buffered row only once + keyToNumValuesMetrics.memoryUsedBytes + keyWithIndexToValueMetrics.memoryUsedBytes, + keyWithIndexToValueMetrics.customMetrics.map { + case (s @ StateStoreCustomSumMetric(_, desc), value) => + s.copy(desc = newDesc(desc)) -> value + case (s @ StateStoreCustomSizeMetric(_, desc), value) => + s.copy(desc = newDesc(desc)) -> value + case (s @ StateStoreCustomTimingMetric(_, desc), value) => + s.copy(desc = newDesc(desc)) -> value + case (s, _) => + throw new IllegalArgumentException( + s"Unknown state store custom metric is found at metrics: $s") + } + ) + } +} +
+private[sql] object StreamingJoinStateManagerImplV1 { + case object KeyToNumValuesType extends StateStoreType { + override def toString(): String = "keyToNumValues" + } + + case object KeyWithIndexToRowValueType extends StateStoreType { + override def toString(): String = "keyWithIndexToValue" + } + + def allStateStoreTypes: Seq[StateStoreType] = Seq(KeyToNumValuesType, KeyWithIndexToRowValueType) +} +
+/** + * Please refer to [[StreamingJoinStateManager]] for documentation that is not related to internals. + * Please also refer to [[StreamingJoinStateManagerImplV1]] for internal details. Here we only + * describe the differences between StreamingJoinStateManagerImplV1 and this class. + * + * This class stores the key -> multiple values in two [[StateStore]]s. + * - Store 1 ([[KeyToNumValuesStore]]): same as StreamingJoinStateManagerImplV1 + * - Store 2 ([[KeyWithIndexToRowAndMatchedStore]]): maintains mapping between (key, index) -> + * (value, matched) - this only changes the type of the value; the operations remain the same + * + * Operations in this class are aware of the change and handle the mapping accordingly.
+ */ +private[sql] class StreamingJoinStateManagerImplV2( + joinSide: JoinSide, + inputValueAttributes: Seq[Attribute], + joinKeys: Seq[Expression], + stateInfo: Option[StatefulOperatorStateInfo], + storeConf: StateStoreConf, + hadoopConf: Configuration) + extends BaseStreamingJoinStateManagerImpl( + joinSide: JoinSide, + inputValueAttributes: Seq[Attribute], + joinKeys: Seq[Expression], + stateInfo: Option[StatefulOperatorStateInfo], + storeConf: StateStoreConf, + hadoopConf: Configuration) with Logging { + + import StreamingJoinStateManagerImplV2._ + + private val keyToNumValues = new KeyToNumValuesStore(KeyToNumValuesType, joinSide, stateInfo, + storeConf, hadoopConf, keyAttributes) + private val keyWithIndexToValue = new KeyWithIndexToRowAndMatchedStore(KeyWithIndexToRowValueType, + joinSide, stateInfo, storeConf, hadoopConf, keyAttributes, inputValueAttributes) + + override def get(key: UnsafeRow): Iterator[UnsafeRow] = { + val numValues = keyToNumValues.get(key) + keyWithIndexToValue.getAll(key, numValues).map(_.value._1) + } + + override def getJoinedRows( + key: UnsafeRow, + generateJoinedRow: InternalRow => JoinedRow, + predicate: JoinedRow => Boolean): Iterator[JoinedRow] = { + val numValues = keyToNumValues.get(key) + keyWithIndexToValue.getAll(key, numValues).map { keyIdxToValue => + val joinedRow = generateJoinedRow(keyIdxToValue.value._1) + if (predicate(joinedRow)) { + val row = keyWithIndexToValue.get(key, keyIdxToValue.valueIndex) + if (!row._2) { + // only update when matched flag is false + keyWithIndexToValue.put(key, keyIdxToValue.valueIndex, row.copy(_2 = true)) + } + joinedRow + } else { + null + } + }.filter(_ != null) + } + + override def append(key: UnsafeRow, value: UnsafeRow, matched: Boolean): Unit = { + val numExistingValues = keyToNumValues.get(key) + keyWithIndexToValue.put(key, numExistingValues, (value, matched)) + keyToNumValues.put(key, numExistingValues + 1) + } + + override def removeByKeyCondition( + removalCondition: UnsafeRow => Boolean): Iterator[KeyToValueAndMatched] = { + StateStoreHandlers.removeByKeyCondition(keyToNumValues, keyWithIndexToValue, + (rowAndMatched: (UnsafeRow, Boolean)) => (rowAndMatched._1, Some(rowAndMatched._2)), + removalCondition) + } + + override def removeByValueCondition( + removalCondition: UnsafeRow => Boolean): Iterator[KeyToValueAndMatched] = { + StateStoreHandlers.removeByValueCondition(keyToNumValues, keyWithIndexToValue, + (rowAndMatched: (UnsafeRow, Boolean)) => (rowAndMatched._1, Some(rowAndMatched._2)), + removalCondition) + } + + override def commit(): Unit = { + keyToNumValues.commit() + keyWithIndexToValue.commit() + } + + override def abortIfNeeded(): Unit = { + keyToNumValues.abortIfNeeded() + keyWithIndexToValue.abortIfNeeded() + } + + override def metrics: StateStoreMetrics = { + val keyToNumValuesMetrics = keyToNumValues.metrics + val keyWithIndexToValueMetrics = keyWithIndexToValue.metrics + def newDesc(desc: String): String = s"${joinSide.toString.toUpperCase(Locale.ROOT)}: $desc" + + StateStoreMetrics( + keyWithIndexToValueMetrics.numKeys, // represent each buffered row only once + keyToNumValuesMetrics.memoryUsedBytes + keyWithIndexToValueMetrics.memoryUsedBytes, + keyWithIndexToValueMetrics.customMetrics.map { + case (s @ StateStoreCustomSumMetric(_, desc), value) => + s.copy(desc = newDesc(desc)) -> value + case (s @ StateStoreCustomSizeMetric(_, desc), value) => + s.copy(desc = newDesc(desc)) -> value + case (s @ StateStoreCustomTimingMetric(_, desc), value) => + s.copy(desc = newDesc(desc)) -> 
value + case (s, _) => + throw new IllegalArgumentException( + s"Unknown state store custom metric is found at metrics: $s") + } + ) + } +} + +private[sql] object StreamingJoinStateManagerImplV2 { + case object KeyToNumValuesType extends StateStoreType { + override def toString(): String = "keyToNumValues" + } + + case object KeyWithIndexToRowValueType extends StateStoreType { + override def toString(): String = "keyWithIndexToValue" + } + + def allStateStoreTypes: Seq[StateStoreType] = Seq(KeyToNumValuesType, KeyWithIndexToRowValueType) +} diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/commits/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/commits/0 new file mode 100644 index 0000000000000..9c1e3021c3ead --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/commits/0 @@ -0,0 +1,2 @@ +v1 +{"nextBatchWatermarkMs":0} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/metadata new file mode 100644 index 0000000000000..543f156048abe --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/metadata @@ -0,0 +1 @@ +{"id":"1ab1ee6f-993c-4a51-824c-1c7cc8202f62"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/offsets/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/offsets/0 new file mode 100644 index 0000000000000..63dba425b7e16 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/offsets/0 @@ -0,0 +1,4 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1548845804202,"conf":{"spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider","spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion":"2","spark.sql.streaming.multipleWatermarkPolicy":"min","spark.sql.streaming.aggregation.stateFormatVersion":"2","spark.sql.shuffle.partitions":"5"}} +0 +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/0/left-keyToNumValues/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/0/left-keyToNumValues/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/0/left-keyToNumValues/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/0/left-keyWithIndexToValue/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/0/left-keyWithIndexToValue/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/0/left-keyWithIndexToValue/1.delta differ diff 
--git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/0/right-keyToNumValues/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/0/right-keyToNumValues/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/0/right-keyToNumValues/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/0/right-keyWithIndexToValue/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/0/right-keyWithIndexToValue/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/0/right-keyWithIndexToValue/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/1/left-keyToNumValues/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/1/left-keyToNumValues/1.delta new file mode 100644 index 0000000000000..2cdf645d3a406 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/1/left-keyToNumValues/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/1/left-keyWithIndexToValue/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/1/left-keyWithIndexToValue/1.delta new file mode 100644 index 0000000000000..9c69d01231196 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/1/left-keyWithIndexToValue/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/1/right-keyToNumValues/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/1/right-keyToNumValues/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/1/right-keyToNumValues/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/1/right-keyWithIndexToValue/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/1/right-keyWithIndexToValue/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/1/right-keyWithIndexToValue/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/2/left-keyToNumValues/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/2/left-keyToNumValues/1.delta new 
file mode 100644 index 0000000000000..4e421cd377fb6 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/2/left-keyToNumValues/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/2/left-keyWithIndexToValue/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/2/left-keyWithIndexToValue/1.delta new file mode 100644 index 0000000000000..edc7a97408aaa Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/2/left-keyWithIndexToValue/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/2/right-keyToNumValues/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/2/right-keyToNumValues/1.delta new file mode 100644 index 0000000000000..4e421cd377fb6 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/2/right-keyToNumValues/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/2/right-keyWithIndexToValue/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/2/right-keyWithIndexToValue/1.delta new file mode 100644 index 0000000000000..edc7a97408aaa Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/2/right-keyWithIndexToValue/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/3/left-keyToNumValues/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/3/left-keyToNumValues/1.delta new file mode 100644 index 0000000000000..859c2b1315a5e Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/3/left-keyToNumValues/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/3/left-keyWithIndexToValue/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/3/left-keyWithIndexToValue/1.delta new file mode 100644 index 0000000000000..7535621b3adb2 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/3/left-keyWithIndexToValue/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/3/right-keyToNumValues/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/3/right-keyToNumValues/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/3/right-keyToNumValues/1.delta differ diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/3/right-keyWithIndexToValue/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/3/right-keyWithIndexToValue/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/3/right-keyWithIndexToValue/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/4/left-keyToNumValues/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/4/left-keyToNumValues/1.delta new file mode 100644 index 0000000000000..0bdaf341003b9 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/4/left-keyToNumValues/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/4/left-keyWithIndexToValue/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/4/left-keyWithIndexToValue/1.delta new file mode 100644 index 0000000000000..f17037b3c5218 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/4/left-keyWithIndexToValue/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/4/right-keyToNumValues/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/4/right-keyToNumValues/1.delta new file mode 100644 index 0000000000000..0bdaf341003b9 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/4/right-keyToNumValues/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/4/right-keyWithIndexToValue/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/4/right-keyWithIndexToValue/1.delta new file mode 100644 index 0000000000000..f17037b3c5218 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/state/0/4/right-keyWithIndexToValue/1.delta differ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/join/StreamingJoinStateManagerSuite.scala similarity index 53% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/join/StreamingJoinStateManagerSuite.scala index c0216a2ef3e61..acc714c2c60fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/join/StreamingJoinStateManagerSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution.streaming.state +package org.apache.spark.sql.execution.streaming.state.join import java.util.UUID @@ -23,83 +23,94 @@ import org.apache.hadoop.conf.Configuration import org.scalatest.BeforeAndAfter import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, GenericInternalRow, LessThanOrEqual, Literal, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.LeftSide +import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf} import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.types._ -class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter { +class StreamingJoinStateManagerSuite extends StreamTest with BeforeAndAfter { before { SparkSession.setActiveSession(spark) // set this before force initializing 'joinExec' spark.streams.stateStoreCoordinator // initialize the lazy coordinator } + test("StreamingJoinStateManager V1 - all operations") { + withJoinStateManager(inputValueAttribs, joinKeyExprs, stateFormatVersion = 1) { manager => + testAllOperations(manager) + } + } - test("SymmetricHashJoinStateManager - all operations") { - withJoinStateManager(inputValueAttribs, joinKeyExprs) { manager => - implicit val mgr = manager - - assert(get(20) === Seq.empty) // initially empty - append(20, 2) - assert(get(20) === Seq(2)) // should first value correctly - assert(numRows === 1) - - append(20, 3) - assert(get(20) === Seq(2, 3)) // should append new values - append(20, 3) - assert(get(20) === Seq(2, 3, 3)) // should append another copy if same value added again - assert(numRows === 3) - - assert(get(30) === Seq.empty) - append(30, 1) - assert(get(30) === Seq(1)) - assert(get(20) === Seq(2, 3, 3)) // add another key-value should not affect existing ones - assert(numRows === 4) - - removeByKey(25) - assert(get(20) === Seq.empty) - assert(get(30) === Seq(1)) // should remove 20, not 30 - assert(numRows === 1) - - removeByKey(30) - assert(get(30) === Seq.empty) // should remove 30 - assert(numRows === 0) - - def appendAndTest(key: Int, values: Int*): Unit = { - values.foreach { value => append(key, value)} - require(get(key) === values) - } + test("StreamingJoinStateManager V2 - all operations") { + withJoinStateManager(inputValueAttribs, joinKeyExprs, stateFormatVersion = 2) { manager => + testAllOperations(manager) + } + } - appendAndTest(40, 100, 200, 300) - appendAndTest(50, 125) - appendAndTest(60, 275) // prepare for testing removeByValue - assert(numRows === 5) - - removeByValue(125) - assert(get(40) === Seq(200, 300)) - assert(get(50) === Seq.empty) - assert(get(60) === Seq(275)) // should remove only some values, not all - assert(numRows === 3) - - append(40, 50) - assert(get(40) === Seq(50, 200, 300)) - assert(numRows === 4) - - removeByValue(200) - assert(get(40) === Seq(300)) - assert(get(60) === Seq(275)) // should remove only some values, not all - assert(numRows === 2) - - removeByValue(300) - assert(get(40) === Seq.empty) - assert(get(60) === Seq.empty) // should remove all values now - assert(numRows === 0) + private def testAllOperations(manager: StreamingJoinStateManager): Unit 
= { + implicit val mgr = manager + + assert(get(20) === Seq.empty) // initially empty + append(20, 2) + assert(get(20) === Seq(2)) // should first value correctly + assert(numRows === 1) + + append(20, 3) + assert(get(20) === Seq(2, 3)) // should append new values + append(20, 3) + assert(get(20) === Seq(2, 3, 3)) // should append another copy if same value added again + assert(numRows === 3) + + assert(get(30) === Seq.empty) + append(30, 1) + assert(get(30) === Seq(1)) + assert(get(20) === Seq(2, 3, 3)) // add another key-value should not affect existing ones + assert(numRows === 4) + + removeByKey(25) + assert(get(20) === Seq.empty) + assert(get(30) === Seq(1)) // should remove 20, not 30 + assert(numRows === 1) + + removeByKey(30) + assert(get(30) === Seq.empty) // should remove 30 + assert(numRows === 0) + + def appendAndTest(key: Int, values: Int*): Unit = { + values.foreach { value => append(key, value)} + require(get(key) === values) } + + appendAndTest(40, 100, 200, 300) + appendAndTest(50, 125) + appendAndTest(60, 275) // prepare for testing removeByValue + assert(numRows === 5) + + removeByValue(125) + assert(get(40) === Seq(200, 300)) + assert(get(50) === Seq.empty) + assert(get(60) === Seq(275)) // should remove only some values, not all + assert(numRows === 3) + + append(40, 50) + assert(get(40) === Seq(50, 200, 300)) + assert(numRows === 4) + + removeByValue(200) + assert(get(40) === Seq(300)) + assert(get(60) === Seq(275)) // should remove only some values, not all + assert(numRows === 2) + + removeByValue(300) + assert(get(40) === Seq.empty) + assert(get(60) === Seq.empty) // should remove all values now + assert(numRows === 0) } + val watermarkMetadata = new MetadataBuilder().putLong(EventTimeWatermark.delayKey, 10).build() val inputValueSchema = new StructType() .add(StructField("time", IntegerType, metadata = watermarkMetadata)) @@ -111,7 +122,6 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter val inputValueGen = UnsafeProjection.create(inputValueAttribs.map(_.dataType).toArray) val joinKeyGen = UnsafeProjection.create(joinKeyExprs.map(_.dataType).toArray) - def toInputValue(i: Int): UnsafeRow = { inputValueGen.apply(new GenericInternalRow(Array[Any](i, false))) } @@ -122,16 +132,16 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter def toValueInt(inputValueRow: UnsafeRow): Int = inputValueRow.getInt(0) - def append(key: Int, value: Int)(implicit manager: SymmetricHashJoinStateManager): Unit = { - manager.append(toJoinKeyRow(key), toInputValue(value)) + def append(key: Int, value: Int)(implicit manager: StreamingJoinStateManager): Unit = { + manager.append(toJoinKeyRow(key), toInputValue(value), matched = false) } - def get(key: Int)(implicit manager: SymmetricHashJoinStateManager): Seq[Int] = { + def get(key: Int)(implicit manager: StreamingJoinStateManager): Seq[Int] = { manager.get(toJoinKeyRow(key)).map(toValueInt).toSeq.sorted } /** Remove keys (and corresponding values) where `time <= threshold` */ - def removeByKey(threshold: Long)(implicit manager: SymmetricHashJoinStateManager): Unit = { + def removeByKey(threshold: Long)(implicit manager: StreamingJoinStateManager): Unit = { val expr = LessThanOrEqual( BoundReference( @@ -142,27 +152,28 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter } /** Remove values where `time <= threshold` */ - def removeByValue(watermark: Long)(implicit manager: SymmetricHashJoinStateManager): Unit = { + def removeByValue(watermark: 
Long)(implicit manager: StreamingJoinStateManager): Unit = { val expr = LessThanOrEqual(inputValueAttribWithWatermark, Literal(watermark)) val iter = manager.removeByValueCondition( GeneratePredicate.generate(expr, inputValueAttribs).eval _) while (iter.hasNext) iter.next() } - def numRows(implicit manager: SymmetricHashJoinStateManager): Long = { + def numRows(implicit manager: StreamingJoinStateManager): Long = { manager.metrics.numKeys } - def withJoinStateManager( inputValueAttribs: Seq[Attribute], - joinKeyExprs: Seq[Expression])(f: SymmetricHashJoinStateManager => Unit): Unit = { + joinKeyExprs: Seq[Expression], + stateFormatVersion: Int)(f: StreamingJoinStateManager => Unit): Unit = { withTempDir { file => val storeConf = new StateStoreConf() val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5) - val manager = new SymmetricHashJoinStateManager( - LeftSide, inputValueAttribs, joinKeyExprs, Some(stateInfo), storeConf, new Configuration) + val manager = StreamingJoinStateManager.createStateManager( + LeftSide, inputValueAttribs, joinKeyExprs, Some(stateInfo), storeConf, new Configuration, + stateFormatVersion) try { f(manager) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 42fe9f34ee3ec..83eee3e6275b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -17,24 +17,20 @@ package org.apache.spark.sql.streaming +import java.io.File import java.util.UUID import scala.util.Random +import org.apache.commons.io.FileUtils import org.scalatest.BeforeAndAfter import org.apache.spark.scheduler.ExecutorCacheTaskLocation -import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SparkSession} -import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, Filter} -import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.execution.{FileSourceScanExec, LogicalRDD} -import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinHelper} +import org.apache.spark.sql.{DataFrame, Row, SparkSession} +import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreProviderId} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils @@ -712,5 +708,189 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with assertNumStateRows(total = 2, updated = 2) ) } + + test("SPARK-26187 self left outer join should not return outer nulls for already matched rows") { + val inputStream = MemoryStream[(Int, Long)] + + val df = inputStream.toDS() + .select(col("_1").as("value"), col("_2").cast("timestamp").as("timestamp")) + + val leftStream = df.select(col("value").as("leftId"), col("timestamp").as("leftTime")) + + val rightStream = df + // Introduce misses for ease of debugging + .where(col("value") % 2 === 0) + .select(col("value").as("rightId"), col("timestamp").as("rightTime")) + + val query = leftStream + 
.withWatermark("leftTime", "5 seconds") + .join( + rightStream.withWatermark("rightTime", "5 seconds"), + expr("leftId = rightId AND rightTime >= leftTime AND " + + "rightTime <= leftTime + interval 5 seconds"), + joinType = "leftOuter") + .select(col("leftId"), col("leftTime").cast("int"), + col("rightId"), col("rightTime").cast("int")) + + testStream(query)( + AddData(inputStream, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)), + // batch 1 - global watermark = 0 + // states + // left: (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L) + // right: (2, 2L), (4, 4L) + CheckNewAnswer((2, 2L, 2, 2L), (4, 4L, 4, 4L)), + assertNumStateRows(7, 7), + + AddData(inputStream, (6, 6L), (7, 7L), (8, 8L), (9, 9L), (10, 10L)), + // batch 2 - global watermark = 5 + // states + // left: (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L), (6, 6L), (7, 7L), (8, 8L), + // (9, 9L), (10, 10L) + // right: (6, 6L), (8, 8L), (10, 10L) + // states evicted + // left: nothing (it waits for 5 seconds more than watermark due to join condition) + // right: (2, 2L), (4, 4L) + // NOTE: look for evicted rows in right which are not evicted from left - they were + // properly joined in batch 1 + CheckNewAnswer((6, 6L, 6, 6L), (8, 8L, 8, 8L), (10, 10L, 10, 10L)), + assertNumStateRows(13, 8), + + AddData(inputStream, (11, 11L), (12, 12L), (13, 13L), (14, 14L), (15, 15L)), + // batch 3 + // - global watermark = 9 <= min(9, 10) + // states + // left: (4, 4L), (5, 5L), (6, 6L), (7, 7L), (8, 8L), (9, 9L), (10, 10L), (11, 11L), + // (12, 12L), (13, 13L), (14, 14L), (15, 15L) + // right: (10, 10L), (12, 12L), (14, 14L) + // states evicted + // left: (1, 1L), (2, 2L), (3, 3L) + // right: (6, 6L), (8, 8L) + CheckNewAnswer( + Row(12, 12L, 12, 12L), Row(14, 14L, 14, 14L), + Row(1, 1L, null, null), Row(3, 3L, null, null)), + assertNumStateRows(15, 7) + ) + } + + test("SPARK-26187 self right outer join should not return outer nulls for already matched rows") { + val inputStream = MemoryStream[(Int, Long)] + + val df = inputStream.toDS() + .select(col("_1").as("value"), col("_2").cast("timestamp").as("timestamp")) + + // we're just flipping "left" and "right" from left outer join and apply right outer join + + val leftStream = df + // Introduce misses for ease of debugging + .where(col("value") % 2 === 0) + .select(col("value").as("leftId"), col("timestamp").as("leftTime")) + + val rightStream = df.select(col("value").as("rightId"), col("timestamp").as("rightTime")) + + val query = leftStream + .withWatermark("leftTime", "5 seconds") + .join( + rightStream.withWatermark("rightTime", "5 seconds"), + expr("leftId = rightId AND leftTime >= rightTime AND " + + "leftTime <= rightTime + interval 5 seconds"), + joinType = "rightOuter") + .select(col("rightId"), col("rightTime").cast("int"), + col("leftId"), col("leftTime").cast("int")) + + // we can just flip left and right in the explanation of left outer query test + // to assume the status of right outer query, hence skip explaining here + testStream(query)( + AddData(inputStream, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)), + CheckNewAnswer((2, 2L, 2, 2L), (4, 4L, 4, 4L)), + assertNumStateRows(7, 7), + + AddData(inputStream, (6, 6L), (7, 7L), (8, 8L), (9, 9L), (10, 10L)), + CheckNewAnswer((6, 6L, 6, 6L), (8, 8L, 8, 8L), (10, 10L, 10, 10L)), + assertNumStateRows(13, 8), + + AddData(inputStream, (11, 11L), (12, 12L), (13, 13L), (14, 14L), (15, 15L)), + CheckNewAnswer( + Row(12, 12L, 12, 12L), Row(14, 14L, 14, 14L), + Row(1, 1L, null, null), Row(3, 3L, null, null)), + assertNumStateRows(15, 7) + ) + } + + 
test("SPARK-26187 restore the query - state format version 1") { + val inputStream = MemoryStream[(Int, Long)] + + val df = inputStream.toDS() + .select(col("_1").as("value"), col("_2").cast("timestamp").as("timestamp")) + + val leftStream = df.select(col("value").as("leftId"), col("timestamp").as("leftTime")) + + val rightStream = df + // Introduce misses for ease of debugging + .where(col("value") % 2 === 0) + .select(col("value").as("rightId"), col("timestamp").as("rightTime")) + + val query = leftStream + .withWatermark("leftTime", "5 seconds") + .join( + rightStream.withWatermark("rightTime", "5 seconds"), + expr("rightId = leftId AND rightTime >= leftTime AND " + + "rightTime <= leftTime + interval 5 seconds"), + joinType = "leftOuter") + .select(col("leftId"), col("leftTime").cast("int"), + col("rightId"), col("rightTime").cast("int")) + + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-version-2.4.0-streaming-join-state-format-1/").toURI + + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. + FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + + inputStream.addData((1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)) + + testStream(query)( + StartStream( + checkpointLocation = checkpointDir.getAbsolutePath, + additionalConfs = Map(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key -> "2")), + + /* + Note: The checkpoint was generated using the following input in Spark version 2.4.0 + + AddData(inputStream, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)), + // batch 1 - global watermark = 0 + // states + // left: (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L) + // right: (2, 2L), (4, 4L) + CheckNewAnswer((2, 2L, 2, 2L), (4, 4L, 4, 4L)), + assertNumStateRows(7, 7), + */ + + AddData(inputStream, (6, 6L), (7, 7L), (8, 8L), (9, 9L), (10, 10L)), + // batch 2: same result as above test + CheckNewAnswer((6, 6L, 6, 6L), (8, 8L, 8, 8L), (10, 10L, 10, 10L)), + assertNumStateRows(13, 8), + + Execute { query => + // Verify state format = 1 + val f = query.lastExecution.executedPlan.collect { + case f: StreamingSymmetricHashJoinExec => f + } + assert(f.size == 1) + assert(f.head.stateFormatVersion == 1) + }, + + AddData(inputStream, (11, 11L), (12, 12L), (13, 13L), (14, 14L), (15, 15L)), + // batch 3: global watermark, remaining rows in states, evicted rows are all same + // The query is running with state format 1, which SPARK-26187 is not fixed. + // Hence (2, 2L, null, null) is also emitted as output as well. + CheckNewAnswer( + Row(12, 12L, 12, 12L), Row(14, 14L, 14, 14L), + Row(1, 1L, null, null), Row(3, 3L, null, null), + Row(2, 2L, null, null)), + assertNumStateRows(15, 7) + ) + } + }