diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
index 8b35822e83fac..828c06ab834ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
@@ -370,7 +370,8 @@ case class StateSourceOptions(
     readChangeFeedOptions: Option[ReadChangeFeedOptions],
     stateVarName: Option[String],
     readRegisteredTimers: Boolean,
-    flattenCollectionTypes: Boolean) {
+    flattenCollectionTypes: Boolean,
+    operatorStateUniqueIds: Option[Array[Array[String]]] = None) {
   def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE)
 
   override def toString: String = {
@@ -567,10 +568,37 @@ object StateSourceOptions extends DataSourceOptions {
       }
     }
 
+    val startBatchId = if (fromSnapshotOptions.isDefined) {
+      fromSnapshotOptions.get.snapshotStartBatchId
+    } else if (readChangeFeedOptions.isDefined) {
+      readChangeFeedOptions.get.changeStartBatchId
+    } else {
+      batchId.get
+    }
+
+    val operatorStateUniqueIds = getOperatorStateUniqueIds(
+      sparkSession,
+      startBatchId,
+      operatorId,
+      resolvedCpLocation)
+
+    if (operatorStateUniqueIds.isDefined) {
+      if (fromSnapshotOptions.isDefined) {
+        throw StateDataSourceErrors.invalidOptionValue(
+          SNAPSHOT_START_BATCH_ID,
+          "Snapshot reading is currently not supported with checkpoint v2.")
+      }
+      if (readChangeFeedOptions.isDefined) {
+        throw StateDataSourceErrors.invalidOptionValue(
+          READ_CHANGE_FEED,
+          "Read change feed is currently not supported with checkpoint v2.")
+      }
+    }
+
     StateSourceOptions(
       resolvedCpLocation, batchId.get, operatorId, storeName, joinSide,
       readChangeFeed, fromSnapshotOptions, readChangeFeedOptions,
-      stateVarName, readRegisteredTimers, flattenCollectionTypes)
+      stateVarName, readRegisteredTimers, flattenCollectionTypes, operatorStateUniqueIds)
   }
 
   private def resolvedCheckpointLocation(
@@ -589,6 +617,20 @@ object StateSourceOptions extends DataSourceOptions {
     }
   }
 
+  private def getOperatorStateUniqueIds(
+      session: SparkSession,
+      batchId: Long,
+      operatorId: Long,
+      checkpointLocation: String): Option[Array[Array[String]]] = {
+    val commitLog = new StreamingQueryCheckpointMetadata(session, checkpointLocation).commitLog
+    val commitMetadata = commitLog.get(batchId) match {
+      case Some(commitMetadata) => commitMetadata
+      case None => throw StateDataSourceErrors.committedBatchUnavailable(checkpointLocation)
+    }
+
+    commitMetadata.stateUniqueIds.flatMap(_.get(operatorId))
+  }
+
   // Modifies options due to external data. Returns modified options.
   // If this is a join operator specifying a store name using state format v3,
   // we need to modify the options.
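Editorial note (not part of the patch): a minimal sketch of how the option resolution above is exercised from the reader side, mirroring the tests added at the end of this diff. The checkpoint path, batch ID, and operator ID are placeholders; the config is assumed to have been set via `SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION = 2` when the query ran.

```scala
// Illustrative only: reading operator state from a query that ran with the
// v2 state store checkpoint format. Path and IDs are placeholders.
val stateDf = spark.read
  .format("statestore")
  .option(StateSourceOptions.BATCH_ID, 0)
  .option(StateSourceOptions.OPERATOR_ID, 0)
  .load("/path/to/checkpoint")

// With a v2 checkpoint, SNAPSHOT_START_BATCH_ID and READ_CHANGE_FEED are now
// rejected during option resolution with STDS_INVALID_OPTION_VALUE.WITH_MESSAGE.
stateDf.show()
```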
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
index 6402eba868ef0..ebef6e3dac552 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
+import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager
 import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType, TransformWithStateVariableInfo}
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.execution.streaming.state.RecordType.{getRecordTypeAsString, RecordType}
@@ -95,6 +96,13 @@ abstract class StatePartitionReaderBase(
       schema, "value").asInstanceOf[StructType]
   }
 
+  protected val getStoreUniqueId: Option[String] = {
+    SymmetricHashJoinStateManager.getStateStoreCheckpointId(
+      storeName = partition.sourceOptions.storeName,
+      partitionId = partition.partition,
+      stateStoreCkptIds = partition.sourceOptions.operatorStateUniqueIds)
+  }
+
   protected lazy val provider: StateStoreProvider = {
     val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
       partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName)
@@ -113,7 +121,9 @@ abstract class StatePartitionReaderBase(
     val isInternal = partition.sourceOptions.readRegisteredTimers
 
     if (useColFamilies) {
-      val store = provider.getStore(partition.sourceOptions.batchId + 1)
+      val store = provider.getStore(
+        partition.sourceOptions.batchId + 1,
+        getStoreUniqueId)
       require(stateStoreColFamilySchemaOpt.isDefined)
       val stateStoreColFamilySchema = stateStoreColFamilySchemaOpt.get
       require(stateStoreColFamilySchema.keyStateEncoderSpec.isDefined)
@@ -171,7 +181,11 @@ class StatePartitionReader(
 
   private lazy val store: ReadStateStore = {
     partition.sourceOptions.fromSnapshotOptions match {
-      case None => provider.getReadStore(partition.sourceOptions.batchId + 1)
+      case None =>
+        provider.getReadStore(
+          partition.sourceOptions.batchId + 1,
+          getStoreUniqueId
+        )
 
       case Some(fromSnapshotOptions) =>
         if (!provider.isInstanceOf[SupportsFineGrainedReplay]) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala
index f1415865db246..0f8a3b3b609f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala
@@ -71,6 +71,28 @@ class StreamStreamJoinStatePartitionReader(
     throw StateDataSourceErrors.internalError("Unexpected join side for stream-stream read!")
   }
 
+  private val usesVirtualColumnFamilies = StreamStreamJoinStateHelper.usesVirtualColumnFamilies(
+    hadoopConf.value,
+    partition.sourceOptions.stateCheckpointLocation.toString,
+    partition.sourceOptions.operatorId)
+
+  private val stateStoreCheckpointIds = SymmetricHashJoinStateManager.getStateStoreCheckpointIds(
+    partition.partition,
+    partition.sourceOptions.operatorStateUniqueIds,
+    usesVirtualColumnFamilies)
+
+  private val keyToNumValuesStateStoreCkptId = if (joinSide == LeftSide) {
+    stateStoreCheckpointIds.left.keyToNumValues
+  } else {
+    stateStoreCheckpointIds.right.keyToNumValues
+  }
+
+  private val keyWithIndexToValueStateStoreCkptId = if (joinSide == LeftSide) {
+    stateStoreCheckpointIds.left.keyWithIndexToValue
+  } else {
+    stateStoreCheckpointIds.right.keyWithIndexToValue
+  }
+
   /*
    * This is to handle the difference of schema across state format versions. The major difference
    * is whether we have added new field(s) in addition to the fields from input schema.
@@ -85,10 +107,7 @@ class StreamStreamJoinStatePartitionReader(
       // column from the value schema to get the actual fields.
       if (maybeMatchedColumn.name == "matched" && maybeMatchedColumn.dataType == BooleanType) {
         // If checkpoint is using one store and virtual column families, version is 3
-        if (StreamStreamJoinStateHelper.usesVirtualColumnFamilies(
-          hadoopConf.value,
-          partition.sourceOptions.stateCheckpointLocation.toString,
-          partition.sourceOptions.operatorId)) {
+        if (usesVirtualColumnFamilies) {
           (valueSchema.dropRight(1), 3)
         } else {
           (valueSchema.dropRight(1), 2)
@@ -130,8 +149,8 @@ class StreamStreamJoinStatePartitionReader(
         storeConf = storeConf,
         hadoopConf = hadoopConf.value,
         partitionId = partition.partition,
-        keyToNumValuesStateStoreCkptId = None,
-        keyWithIndexToValueStateStoreCkptId = None,
+        keyToNumValuesStateStoreCkptId = keyToNumValuesStateStoreCkptId,
+        keyWithIndexToValueStateStoreCkptId = keyWithIndexToValueStateStoreCkptId,
         formatVersion,
         skippedNullValueCount = None,
         useStateStoreCoordinator = false,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala
index f424d2892dfa1..ef37185ce4166 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala
@@ -351,7 +351,7 @@ case class StreamingSymmetricHashJoinExec(
     assert(stateInfo.isDefined, "State info not defined")
 
     val checkpointIds = SymmetricHashJoinStateManager.getStateStoreCheckpointIds(
-      partitionId, stateInfo.get, useVirtualColumnFamilies)
+      partitionId, stateInfo.get.stateStoreCkptIds, useVirtualColumnFamilies)
 
     val inputSchema = left.output ++ right.output
     val postJoinFilter =
@@ -363,12 +363,12 @@ case class StreamingSymmetricHashJoinExec(
       new OneSideHashJoiner(
         LeftSide, left.output, leftKeys, leftInputIter, condition.leftSideOnly,
         postJoinFilter, stateWatermarkPredicates.left, partitionId,
-        checkpointIds.left.keyToNumValues, checkpointIds.left.valueToNumKeys,
+        checkpointIds.left.keyToNumValues, checkpointIds.left.keyWithIndexToValue,
         skippedNullValueCount, joinStateManagerStoreGenerator),
       new OneSideHashJoiner(
         RightSide, right.output, rightKeys, rightInputIter, condition.rightSideOnly,
         postJoinFilter, stateWatermarkPredicates.right, partitionId,
-        checkpointIds.right.keyToNumValues, checkpointIds.right.valueToNumKeys,
+        checkpointIds.right.keyToNumValues, checkpointIds.right.keyWithIndexToValue,
         skippedNullValueCount, joinStateManagerStoreGenerator))
 
     // Join one side input using the other side's buffered/state rows. Here is how it is done.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinHelper.scala
index 6f02a17efe340..7b02a43cd5a9f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinHelper.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinHelper.scala
@@ -324,15 +324,15 @@ object StreamingSymmetricHashJoinHelper extends Logging {
 
   case class JoinerStateStoreCkptInfo(
       keyToNumValues: StateStoreCheckpointInfo,
-      valueToNumKeys: StateStoreCheckpointInfo)
+      keyWithIndexToValue: StateStoreCheckpointInfo)
 
   case class JoinStateStoreCkptInfo(
       left: JoinerStateStoreCkptInfo,
       right: JoinerStateStoreCkptInfo)
 
   case class JoinerStateStoreCheckpointId(
-      keyToNumValues: Option[String],
-      valueToNumKeys: Option[String])
+      keyToNumValues: Option[String],
+      keyWithIndexToValue: Option[String])
 
   case class JoinStateStoreCheckpointId(
       left: JoinerStateStoreCheckpointId,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala
index 4ba6dcced5335..c0965747722e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala
@@ -1135,17 +1135,17 @@ object SymmetricHashJoinStateManager {
     val ckptIds = joinCkptInfo.left.keyToNumValues.stateStoreCkptId.map(
       Array(
         _,
-        joinCkptInfo.left.valueToNumKeys.stateStoreCkptId.get,
+        joinCkptInfo.left.keyWithIndexToValue.stateStoreCkptId.get,
         joinCkptInfo.right.keyToNumValues.stateStoreCkptId.get,
-        joinCkptInfo.right.valueToNumKeys.stateStoreCkptId.get
+        joinCkptInfo.right.keyWithIndexToValue.stateStoreCkptId.get
       )
     )
     val baseCkptIds = joinCkptInfo.left.keyToNumValues.baseStateStoreCkptId.map(
       Array(
         _,
-        joinCkptInfo.left.valueToNumKeys.baseStateStoreCkptId.get,
+        joinCkptInfo.left.keyWithIndexToValue.baseStateStoreCkptId.get,
         joinCkptInfo.right.keyToNumValues.baseStateStoreCkptId.get,
-        joinCkptInfo.right.valueToNumKeys.baseStateStoreCkptId.get
+        joinCkptInfo.right.keyWithIndexToValue.baseStateStoreCkptId.get
       )
     )
 
@@ -1158,35 +1158,79 @@
 
   /**
    * Stream-stream join has 4 state stores instead of one. So it will generate 4 different
-   * checkpoint IDs. They are translated from each joiners' state store into an array through
-   * mergeStateStoreCheckpointInfo(). This function is used to read it back into individual state
-   * store checkpoint IDs.
-   * @param partitionId
-   * @param stateInfo
-   * @return
+   * checkpoint IDs using stateStoreCkptIds. They are translated from each joiner's state
+   * store into an array through mergeStateStoreCheckpointInfo(). This function is used to read
+   * it back into individual state store checkpoint IDs for each store.
+   * If useColumnFamiliesForJoins is true, then it will always return the first checkpoint ID.
+   *
+   * @param partitionId the partition ID of the state store
+   * @param stateStoreCkptIds the array of checkpoint IDs for all the state stores
+   * @param useColumnFamiliesForJoins whether virtual column families are used for the join
+   *
+   * @return the checkpoint IDs for all state stores used by this joiner
    */
   def getStateStoreCheckpointIds(
       partitionId: Int,
-      stateInfo: StatefulOperatorStateInfo,
+      stateStoreCkptIds: Option[Array[Array[String]]],
       useColumnFamiliesForJoins: Boolean): JoinStateStoreCheckpointId = {
     if (useColumnFamiliesForJoins) {
-      val ckpt = stateInfo.stateStoreCkptIds.map(_(partitionId)).map(_.head)
+      val ckpt = stateStoreCkptIds.map(_(partitionId)).map(_.head)
       JoinStateStoreCheckpointId(
-        left = JoinerStateStoreCheckpointId(keyToNumValues = ckpt, valueToNumKeys = ckpt),
-        right = JoinerStateStoreCheckpointId(keyToNumValues = ckpt, valueToNumKeys = ckpt)
+        left = JoinerStateStoreCheckpointId(keyToNumValues = ckpt, keyWithIndexToValue = ckpt),
+        right = JoinerStateStoreCheckpointId(keyToNumValues = ckpt, keyWithIndexToValue = ckpt)
       )
     } else {
-      val stateStoreCkptIds = stateInfo.stateStoreCkptIds
+      val stateStoreCkptIdsOpt = stateStoreCkptIds
         .map(_(partitionId))
         .map(_.map(Option(_)))
         .getOrElse(Array.fill[Option[String]](4)(None))
       JoinStateStoreCheckpointId(
         left = JoinerStateStoreCheckpointId(
-          keyToNumValues = stateStoreCkptIds(0),
-          valueToNumKeys = stateStoreCkptIds(1)),
+          keyToNumValues = stateStoreCkptIdsOpt(0),
+          keyWithIndexToValue = stateStoreCkptIdsOpt(1)),
         right = JoinerStateStoreCheckpointId(
-          keyToNumValues = stateStoreCkptIds(2),
-          valueToNumKeys = stateStoreCkptIds(3)))
+          keyToNumValues = stateStoreCkptIdsOpt(2),
+          keyWithIndexToValue = stateStoreCkptIdsOpt(3)))
+    }
+  }
+
+  /**
+   * Stream-stream join has 4 state stores instead of one. So it will generate 4 different
+   * checkpoint IDs when not using virtual column families.
+   * This function is used to get the checkpoint ID for a specific state store
+   * by the name of the store, partition ID and the stateStoreCkptIds array. The expected names
+   * for the stores are generated by getStateStoreName().
+   * If useColumnFamiliesForJoins is true, then it will always return the first checkpoint ID.
+   *
+   * @param storeName the name of the state store
+   * @param partitionId the partition ID of the state store
+   * @param stateStoreCkptIds the array of checkpoint IDs for all the state stores
+   * @param useColumnFamiliesForJoins whether virtual column families are used for the join
+   *
+   * @return the checkpoint ID for the specific state store, or None if not found
+   */
+  def getStateStoreCheckpointId(
+      storeName: String,
+      partitionId: Int,
+      stateStoreCkptIds: Option[Array[Array[String]]],
+      useColumnFamiliesForJoins: Boolean = false): Option[String] = {
+    if (useColumnFamiliesForJoins || storeName == StateStoreId.DEFAULT_STORE_NAME) {
+      stateStoreCkptIds.map(_(partitionId)).map(_.head)
+    } else {
+      val joinStateStoreCkptIds = getStateStoreCheckpointIds(
+        partitionId, stateStoreCkptIds, useColumnFamiliesForJoins)
+
+      if (storeName == getStateStoreName(LeftSide, KeyToNumValuesType)) {
+        joinStateStoreCkptIds.left.keyToNumValues
+      } else if (storeName == getStateStoreName(RightSide, KeyToNumValuesType)) {
+        joinStateStoreCkptIds.right.keyToNumValues
+      } else if (storeName == getStateStoreName(LeftSide, KeyWithIndexToValueType)) {
+        joinStateStoreCkptIds.left.keyWithIndexToValue
+      } else if (storeName == getStateStoreName(RightSide, KeyWithIndexToValueType)) {
+        joinStateStoreCkptIds.right.keyWithIndexToValue
+      } else {
+        None
+      }
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
index deda590645de5..d744304afb429 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
@@ -23,7 +23,7 @@ import java.util.UUID
 import org.apache.hadoop.conf.Configuration
 import org.scalatest.Assertions
 
-import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
+import org.apache.spark.{SparkException, SparkThrowable, SparkUnsupportedOperationException}
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.sql.{AnalysisException, DataFrame, Encoders, Row}
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericInternalRow}
@@ -617,6 +617,98 @@ StateDataSourceReadSuite {
   }
 }
 
+class RocksDBWithCheckpointV2StateDataSourceReaderSuite extends StateDataSourceReadSuite {
+  override protected def newStateStoreProvider(): RocksDBStateStoreProvider =
+    new RocksDBStateStoreProvider
+
+  import testImplicits._
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    spark.conf.set(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION, 2)
+    spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+      newStateStoreProvider().getClass.getName)
+    spark.conf.set("spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled",
+      "true")
+  }
+
+  // TODO: Remove this test once we allow migrations from checkpoint v1 to v2
+  test("reading checkpoint v2 store with version 1 should fail") {
+    withTempDir { tmpDir =>
+      val inputData = MemoryStream[(Int, Long)]
+      val query = getStreamStreamJoinQuery(inputData)
+      testStream(query)(
+        StartStream(checkpointLocation = tmpDir.getCanonicalPath),
+        AddData(inputData, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)),
+        ProcessAllAvailable(),
+        Execute { _ => Thread.sleep(2000) },
+        StopStream
+      )
+
+      withSQLConf(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "1") {
+        // Verify reading state throws error when reading checkpoint v2 with version 1
+        val exc = intercept[IllegalStateException] {
+          val stateDf = spark.read.format("statestore")
+            .option(StateSourceOptions.BATCH_ID, 0)
+            .option(StateSourceOptions.OPERATOR_ID, 0)
+            .load(tmpDir.getCanonicalPath)
+          stateDf.collect()
+        }
+
+        checkError(exc.getCause.asInstanceOf[SparkThrowable],
+          "INVALID_LOG_VERSION.EXACT_MATCH_VERSION", "KD002",
+          Map(
+            "version" -> "2",
+            "matchVersion" -> "1"))
+      }
+    }
+  }
+
+  test("check unsupported modes with checkpoint v2") {
+    withTempDir { tmpDir =>
+      val inputData = MemoryStream[(Int, Long)]
+      val query = getStreamStreamJoinQuery(inputData)
+      testStream(query)(
+        StartStream(checkpointLocation = tmpDir.getCanonicalPath),
+        AddData(inputData, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)),
+        ProcessAllAvailable(),
+        Execute { _ => Thread.sleep(2000) },
+        StopStream
+      )
+
+      // Verify reading snapshot throws error with checkpoint v2
+      val exc1 = intercept[StateDataSourceInvalidOptionValue] {
+        val stateSnapshotDf = spark.read.format("statestore")
+          .option("snapshotPartitionId", 2)
+          .option("snapshotStartBatchId", 0)
+          .option("joinSide", "left")
+          .load(tmpDir.getCanonicalPath)
+        stateSnapshotDf.collect()
+      }
+
+      checkError(exc1, "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE", "42616",
+        Map(
+          "optionName" -> StateSourceOptions.SNAPSHOT_START_BATCH_ID,
+          "message" -> "Snapshot reading is currently not supported with checkpoint v2."))
+
+      // Verify reading change feed throws error with checkpoint v2
+      val exc2 = intercept[StateDataSourceInvalidOptionValue] {
+        val stateDf = spark.read.format("statestore")
+          .option(StateSourceOptions.READ_CHANGE_FEED, value = true)
+          .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0)
+          .option(StateSourceOptions.CHANGE_END_BATCH_ID, 1)
+          .load(tmpDir.getAbsolutePath)
+        stateDf.collect()
+      }
+
+      checkError(exc2, "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE", "42616",
+        Map(
+          "optionName" -> StateSourceOptions.READ_CHANGE_FEED,
+          "message" -> "Read change feed is currently not supported with checkpoint v2."))
+    }
+  }
+}
+
 abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Assertions {
 
   import testImplicits._