From eaed9686193c863856be58069afe66caad424995 Mon Sep 17 00:00:00 2001 From: Dylan Wong Date: Fri, 15 Aug 2025 18:59:43 +0000 Subject: [PATCH 01/11] Initial --- .../v2/state/StateDataSource.scala | 46 ++++++++++++++++++- .../v2/state/StatePartitionReader.scala | 36 ++++++++++++++- ...StreamStreamJoinStatePartitionReader.scala | 31 ++++++++++--- .../join/StreamingSymmetricHashJoinExec.scala | 2 +- .../join/SymmetricHashJoinStateManager.scala | 14 +++--- .../state/RocksDBStateStoreProvider.scala | 2 +- .../v2/state/StateDataSourceReadSuite.scala | 8 ++++ 7 files changed, 120 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index 8b35822e83fa..a762df8253d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -370,7 +370,8 @@ case class StateSourceOptions( readChangeFeedOptions: Option[ReadChangeFeedOptions], stateVarName: Option[String], readRegisteredTimers: Boolean, - flattenCollectionTypes: Boolean) { + flattenCollectionTypes: Boolean, + operatorStateUniqueIds: Option[Array[Array[String]]] = None) { def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE) override def toString: String = { @@ -567,10 +568,31 @@ object StateSourceOptions extends DataSourceOptions { } } + + val startBatchId = if (fromSnapshotOptions.isDefined) { + fromSnapshotOptions.get.snapshotStartBatchId + } else if (readChangeFeedOptions.isDefined) { + readChangeFeedOptions.get.changeStartBatchId + } else { + batchId.get + } + + val operatorStateUniqueIds = getOperatorStateUniqueIds( + sparkSession, + startBatchId, + operatorId, + resolvedCpLocation) + + if (operatorStateUniqueIds.isDefined + && (fromSnapshotOptions.isDefined || readChangeFeedOptions.isDefined)) { + throw new UnsupportedOperationException( + "Reading from snapshot or reading change feed is not supported yet with Checkpoint v2.") + } + StateSourceOptions( resolvedCpLocation, batchId.get, operatorId, storeName, joinSide, readChangeFeed, fromSnapshotOptions, readChangeFeedOptions, - stateVarName, readRegisteredTimers, flattenCollectionTypes) + stateVarName, readRegisteredTimers, flattenCollectionTypes, operatorStateUniqueIds) } private def resolvedCheckpointLocation( @@ -589,6 +611,26 @@ object StateSourceOptions extends DataSourceOptions { } } + private def getOperatorStateUniqueIds( + session: SparkSession, + batchId: Long, + operatorId: Long, + checkpointLocation: String): Option[Array[Array[String]]] = { + val commitLog = new StreamingQueryCheckpointMetadata(session, checkpointLocation).commitLog + val commitMetadata = commitLog.get(batchId) match { + case Some(commitMetadata) => commitMetadata + case None => throw StateDataSourceErrors.committedBatchUnavailable(checkpointLocation) + } + + val operatorStateUniqueIds = if (commitMetadata.stateUniqueIds.isDefined) { + Some(commitMetadata.stateUniqueIds.get(operatorId)) + } else { + None + } + + operatorStateUniqueIds + } + // Modifies options due to external data. Returns modified options. // If this is a join operator specifying a store name using state format v3, // we need to modify the options. 
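Note on the data shape consumed above: as used in this series, CommitMetadata.stateUniqueIds appears to be an Option[Map[Long, Array[Array[String]]]] keyed by operator ID, with each value indexed first by partition and then by store. A minimal, self-contained sketch of that lookup follows; the object name and ID values are illustrative only (the real entries are UUIDs recorded in the commit log):

object UniqueIdLookupSketch {
  def main(args: Array[String]): Unit = {
    // operatorId -> per-partition, per-store checkpoint IDs (made-up values)
    val stateUniqueIds: Option[Map[Long, Array[Array[String]]]] =
      Some(Map(0L -> Array(Array("uuid-p0"), Array("uuid-p1"))))

    // Resolve the operator's IDs; a later commit in this series lands on this
    // flatMap form, which yields None when the batch predates checkpoint v2.
    val operatorStateUniqueIds: Option[Array[Array[String]]] =
      stateUniqueIds.flatMap(_.get(0L))

    // A partition reader for the default store then picks its own entry.
    val storeUniqueId: Option[String] =
      operatorStateUniqueIds.map(_(1)).map(_.head)

    println(storeUniqueId) // Some(uuid-p1)
  }
}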
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index 6402eba868ef..7cd365faee4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil +import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType, TransformWithStateVariableInfo} import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.execution.streaming.state.RecordType.{getRecordTypeAsString, RecordType} @@ -95,6 +96,31 @@ abstract class StatePartitionReaderBase( schema, "value").asInstanceOf[StructType] } + protected val getStoreUniqueId : Option[String] = { + val partitionStateUniqueIds = + partition.sourceOptions.operatorStateUniqueIds.map(_(partition.partition)) + if (partition.sourceOptions.storeName == StateStoreId.DEFAULT_STORE_NAME) { + partitionStateUniqueIds.map(_.head) + } else { + val stateStoreCheckpointIds = SymmetricHashJoinStateManager.getStateStoreCheckpointIds( + partition.partition, + partition.sourceOptions.operatorStateUniqueIds, + useColumnFamiliesForJoins = false) + + if (partition.sourceOptions.storeName == "left-keyToNumValues") { + stateStoreCheckpointIds.left.keyToNumValues + } else if (partition.sourceOptions.storeName == "left-keyWithIndexToValue") { + stateStoreCheckpointIds.left.valueToNumKeys + } else if (partition.sourceOptions.storeName == "right-keyToNumValues") { + stateStoreCheckpointIds.right.keyToNumValues + } else if (partition.sourceOptions.storeName == "right-keyWithIndexToValue") { + stateStoreCheckpointIds.right.valueToNumKeys + } else { + None + } + } + } + protected lazy val provider: StateStoreProvider = { val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString, partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName) @@ -113,7 +139,9 @@ abstract class StatePartitionReaderBase( val isInternal = partition.sourceOptions.readRegisteredTimers if (useColFamilies) { - val store = provider.getStore(partition.sourceOptions.batchId + 1) + val store = provider.getStore( + partition.sourceOptions.batchId + 1, + getStoreUniqueId) require(stateStoreColFamilySchemaOpt.isDefined) val stateStoreColFamilySchema = stateStoreColFamilySchemaOpt.get require(stateStoreColFamilySchema.keyStateEncoderSpec.isDefined) @@ -171,7 +199,11 @@ class StatePartitionReader( private lazy val store: ReadStateStore = { partition.sourceOptions.fromSnapshotOptions match { - case None => provider.getReadStore(partition.sourceOptions.batchId + 1) + case None => + provider.getReadStore( + partition.sourceOptions.batchId + 1, + getStoreUniqueId + ) case Some(fromSnapshotOptions) => if (!provider.isInstanceOf[SupportsFineGrainedReplay]) { diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala index f1415865db24..9795b244fc6c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala @@ -71,6 +71,28 @@ class StreamStreamJoinStatePartitionReader( throw StateDataSourceErrors.internalError("Unexpected join side for stream-stream read!") } + private val usesVirtualColumnFamilies = StreamStreamJoinStateHelper.usesVirtualColumnFamilies( + hadoopConf.value, + partition.sourceOptions.stateCheckpointLocation.toString, + partition.sourceOptions.operatorId) + + private val stateStoreCheckpointIds = SymmetricHashJoinStateManager.getStateStoreCheckpointIds( + partition.partition, + partition.sourceOptions.operatorStateUniqueIds, + usesVirtualColumnFamilies) + + private val keyToNumValuesStateStoreCkptId = if (joinSide == LeftSide) { + stateStoreCheckpointIds.left.keyToNumValues + } else { + stateStoreCheckpointIds.right.keyToNumValues + } + + private val keyWithIndexToValueStateStoreCkptId = if (joinSide == LeftSide) { + stateStoreCheckpointIds.left.valueToNumKeys + } else { + stateStoreCheckpointIds.right.valueToNumKeys + } + /* * This is to handle the difference of schema across state format versions. The major difference * is whether we have added new field(s) in addition to the fields from input schema. @@ -85,10 +107,7 @@ class StreamStreamJoinStatePartitionReader( // column from the value schema to get the actual fields. 
if (maybeMatchedColumn.name == "matched" && maybeMatchedColumn.dataType == BooleanType) { // If checkpoint is using one store and virtual column families, version is 3 - if (StreamStreamJoinStateHelper.usesVirtualColumnFamilies( - hadoopConf.value, - partition.sourceOptions.stateCheckpointLocation.toString, - partition.sourceOptions.operatorId)) { + if (usesVirtualColumnFamilies) { (valueSchema.dropRight(1), 3) } else { (valueSchema.dropRight(1), 2) @@ -130,8 +149,8 @@ class StreamStreamJoinStatePartitionReader( storeConf = storeConf, hadoopConf = hadoopConf.value, partitionId = partition.partition, - keyToNumValuesStateStoreCkptId = None, - keyWithIndexToValueStateStoreCkptId = None, + keyToNumValuesStateStoreCkptId = keyToNumValuesStateStoreCkptId, + keyWithIndexToValueStateStoreCkptId = keyWithIndexToValueStateStoreCkptId, formatVersion, skippedNullValueCount = None, useStateStoreCoordinator = false, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala index f424d2892dfa..28024598ac25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala @@ -351,7 +351,7 @@ case class StreamingSymmetricHashJoinExec( assert(stateInfo.isDefined, "State info not defined") val checkpointIds = SymmetricHashJoinStateManager.getStateStoreCheckpointIds( - partitionId, stateInfo.get, useVirtualColumnFamilies) + partitionId, stateInfo.get.stateStoreCkptIds, useVirtualColumnFamilies) val inputSchema = left.output ++ right.output val postJoinFilter = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala index 4ba6dcced533..3068e5bc58af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala @@ -1167,26 +1167,26 @@ object SymmetricHashJoinStateManager { */ def getStateStoreCheckpointIds( partitionId: Int, - stateInfo: StatefulOperatorStateInfo, + stateStoreCkptIds: Option[Array[Array[String]]], useColumnFamiliesForJoins: Boolean): JoinStateStoreCheckpointId = { if (useColumnFamiliesForJoins) { - val ckpt = stateInfo.stateStoreCkptIds.map(_(partitionId)).map(_.head) + val ckpt = stateStoreCkptIds.map(_(partitionId)).map(_.head) JoinStateStoreCheckpointId( left = JoinerStateStoreCheckpointId(keyToNumValues = ckpt, valueToNumKeys = ckpt), right = JoinerStateStoreCheckpointId(keyToNumValues = ckpt, valueToNumKeys = ckpt) ) } else { - val stateStoreCkptIds = stateInfo.stateStoreCkptIds + val stateStoreCkptIdsOpt = stateStoreCkptIds .map(_(partitionId)) .map(_.map(Option(_))) .getOrElse(Array.fill[Option[String]](4)(None)) JoinStateStoreCheckpointId( left = JoinerStateStoreCheckpointId( - keyToNumValues = stateStoreCkptIds(0), - valueToNumKeys = stateStoreCkptIds(1)), + keyToNumValues = stateStoreCkptIdsOpt(0), + valueToNumKeys = stateStoreCkptIdsOpt(1)), right = 
JoinerStateStoreCheckpointId( - keyToNumValues = stateStoreCkptIds(2), - valueToNumKeys = stateStoreCkptIds(3))) + keyToNumValues = stateStoreCkptIdsOpt(2), + valueToNumKeys = stateStoreCkptIdsOpt(3))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 7098fd41f402..095d69f1be04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -690,7 +690,7 @@ private[sql] class RocksDBStateStoreProvider rocksDB.load( version, - stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None, + stateStoreCkptId = uniqueId, readOnly = readOnly) // Create or reuse store instance diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala index deda590645de..8d50a4da3f83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala @@ -617,6 +617,14 @@ StateDataSourceReadSuite { } } +class RocksDBWithCheckpointV2StateDataSourceReaderSuite + extends RocksDBWithChangelogCheckpointStateDataSourceReaderSuite { + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION, 2) + } +} + abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Assertions { import testImplicits._ From 188491df6c27b5fca410fddc458ed41d56d05da4 Mon Sep 17 00:00:00 2001 From: Dylan Wong Date: Mon, 18 Aug 2025 05:26:10 +0000 Subject: [PATCH 02/11] Add tests for invalid options --- .../v2/state/StateDataSource.scala | 15 ++++-- .../v2/state/StateDataSourceReadSuite.scala | 54 ++++++++++++++++++- 2 files changed, 63 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index a762df8253d1..e54b570565b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -583,10 +583,17 @@ object StateSourceOptions extends DataSourceOptions { operatorId, resolvedCpLocation) - if (operatorStateUniqueIds.isDefined - && (fromSnapshotOptions.isDefined || readChangeFeedOptions.isDefined)) { - throw new UnsupportedOperationException( - "Reading from snapshot or reading change feed is not supported yet with Checkpoint v2.") + if (operatorStateUniqueIds.isDefined) { + if (fromSnapshotOptions.isDefined) { + throw StateDataSourceErrors.invalidOptionValue( + SNAPSHOT_START_BATCH_ID, + "Snapshot reading is currently not supported with checkpoint v2.") + } + if (readChangeFeedOptions.isDefined) { + throw StateDataSourceErrors.invalidOptionValue( + READ_CHANGE_FEED, + "Read change feed is currently not supported with checkpoint v2.") + } } StateSourceOptions( diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala index 8d50a4da3f83..937bee326735 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala @@ -617,11 +617,61 @@ StateDataSourceReadSuite { } } -class RocksDBWithCheckpointV2StateDataSourceReaderSuite - extends RocksDBWithChangelogCheckpointStateDataSourceReaderSuite { +class RocksDBWithCheckpointV2StateDataSourceReaderSuite extends StateDataSourceReadSuite { + override protected def newStateStoreProvider(): RocksDBStateStoreProvider = + new RocksDBStateStoreProvider + + import testImplicits._ + override def beforeAll(): Unit = { super.beforeAll() spark.conf.set(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION, 2) + spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key, + newStateStoreProvider().getClass.getName) + } + + test("check unsupported modes with checkpoint v2") { + withTempDir { tmpDir => + val inputData = MemoryStream[(Int, Long)] + val query = getStreamStreamJoinQuery(inputData) + testStream(query)( + StartStream(checkpointLocation = tmpDir.getCanonicalPath), + AddData(inputData, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)), + ProcessAllAvailable(), + Execute { _ => Thread.sleep(2000) }, + StopStream + ) + + // Verify reading snapshot throws error with checkpoint v2 + val exc1 = intercept[StateDataSourceInvalidOptionValue] { + val stateSnapshotDf = spark.read.format("statestore") + .option("snapshotPartitionId", 2) + .option("snapshotStartBatchId", 0) + .option("joinSide", "left") + .load(tmpDir.getCanonicalPath) + stateSnapshotDf.collect() + } + + checkError(exc1, "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE", "42616", + Map( + "optionName" -> StateSourceOptions.SNAPSHOT_START_BATCH_ID, + "message" -> "Snapshot reading is currently not supported with checkpoint v2.")) + + // Verify reading change feed throws error with checkpoint v2 + val exc2 = intercept[StateDataSourceInvalidOptionValue] { + val stateDf = spark.read.format("statestore") + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 1) + .load(tmpDir.getAbsolutePath) + stateDf.collect() + } + + checkError(exc2, "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE", "42616", + Map( + "optionName" -> StateSourceOptions.READ_CHANGE_FEED, + "message" -> "Read change feed is currently not supported with checkpoint v2.")) + } } } From dd1eb2a1c7261ab1b585b1291eb8721f58eff6f9 Mon Sep 17 00:00:00 2001 From: Dylan Wong Date: Mon, 18 Aug 2025 17:04:47 +0000 Subject: [PATCH 03/11] move store names to constants --- .../v2/state/StatePartitionReader.scala | 30 ++++++++++++------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index 7cd365faee4a..c76d59290d63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -29,6 +29,16 @@ import 
org.apache.spark.sql.types.{NullType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{NextIterator, SerializableConfiguration} +/** + * Constants for store names used in Stream-Stream joins. + */ +object StatePartitionReaderStoreNames { + val LEFT_KEY_TO_NUM_VALUES_STORE = "left-keyToNumValues" + val LEFT_KEY_WITH_INDEX_TO_VALUE_STORE = "left-keyWithIndexToValue" + val RIGHT_KEY_TO_NUM_VALUES_STORE = "right-keyToNumValues" + val RIGHT_KEY_WITH_INDEX_TO_VALUE_STORE = "right-keyWithIndexToValue" +} + /** * An implementation of [[PartitionReaderFactory]] for State data source. This is used to support * general read from a state store instance, rather than specific to the operator. @@ -107,16 +117,16 @@ abstract class StatePartitionReaderBase( partition.sourceOptions.operatorStateUniqueIds, useColumnFamiliesForJoins = false) - if (partition.sourceOptions.storeName == "left-keyToNumValues") { - stateStoreCheckpointIds.left.keyToNumValues - } else if (partition.sourceOptions.storeName == "left-keyWithIndexToValue") { - stateStoreCheckpointIds.left.valueToNumKeys - } else if (partition.sourceOptions.storeName == "right-keyToNumValues") { - stateStoreCheckpointIds.right.keyToNumValues - } else if (partition.sourceOptions.storeName == "right-keyWithIndexToValue") { - stateStoreCheckpointIds.right.valueToNumKeys - } else { - None + partition.sourceOptions.storeName match { + case StatePartitionReaderStoreNames.LEFT_KEY_TO_NUM_VALUES_STORE => + stateStoreCheckpointIds.left.keyToNumValues + case StatePartitionReaderStoreNames.LEFT_KEY_WITH_INDEX_TO_VALUE_STORE => + stateStoreCheckpointIds.left.valueToNumKeys + case StatePartitionReaderStoreNames.RIGHT_KEY_TO_NUM_VALUES_STORE => + stateStoreCheckpointIds.right.keyToNumValues + case StatePartitionReaderStoreNames.RIGHT_KEY_WITH_INDEX_TO_VALUE_STORE => + stateStoreCheckpointIds.right.valueToNumKeys + case _ => None } } } From e2104d600020e8259b73b4dae5f506066b958120 Mon Sep 17 00:00:00 2001 From: Dylan Wong Date: Mon, 18 Aug 2025 21:42:32 +0000 Subject: [PATCH 04/11] PR fixes --- .../sql/execution/datasources/v2/state/StateDataSource.scala | 4 +--- .../datasources/v2/state/StateDataSourceReadSuite.scala | 2 ++ 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index e54b570565b7..874638f29c99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -629,13 +629,11 @@ object StateSourceOptions extends DataSourceOptions { case None => throw StateDataSourceErrors.committedBatchUnavailable(checkpointLocation) } - val operatorStateUniqueIds = if (commitMetadata.stateUniqueIds.isDefined) { + if (commitMetadata.stateUniqueIds.isDefined) { Some(commitMetadata.stateUniqueIds.get(operatorId)) } else { None } - - operatorStateUniqueIds } // Modifies options due to external data. Returns modified options. 
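Note: the if/else above still unwraps the Option and indexes the map with Map.apply (via stateUniqueIds.get(operatorId)), so a missing operator ID would throw; the next commit simplifies this to flatMap over Map.get, which yields None instead. A standalone illustration of the difference, using toy values (object name and IDs are hypothetical):

object OptionLookupSketch {
  def main(args: Array[String]): Unit = {
    val ids: Option[Map[Long, String]] = Some(Map(0L -> "uuid-a"))

    // Map.get is total: a missing operator ID simply becomes None.
    println(ids.flatMap(_.get(0L))) // Some(uuid-a)
    println(ids.flatMap(_.get(7L))) // None

    // Map.apply throws for a missing key, so that form is only safe when the
    // operator ID is known to be present in the commit metadata.
    println(ids.map(_(0L)))         // Some(uuid-a)
    // ids.map(_(7L))               // would throw NoSuchElementException
  }
}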
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala index 937bee326735..7be66ac2970a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala @@ -628,6 +628,8 @@ class RocksDBWithCheckpointV2StateDataSourceReaderSuite extends StateDataSourceR spark.conf.set(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION, 2) spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key, newStateStoreProvider().getClass.getName) + spark.conf.set("spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled", + "true") } test("check unsupported modes with checkpoint v2") { From 03ca2ff1b467e72e4d1f90affe29d8f14d91a1e5 Mon Sep 17 00:00:00 2001 From: Dylan Wong Date: Wed, 20 Aug 2025 16:31:06 +0000 Subject: [PATCH 05/11] PR fixes --- .../v2/state/StateDataSource.scala | 6 +-- .../v2/state/StatePartitionReader.scala | 36 ++--------------- .../join/SymmetricHashJoinStateManager.scala | 39 ++++++++++++++++++- 3 files changed, 43 insertions(+), 38 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index 874638f29c99..3d4f230c4443 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -629,11 +629,7 @@ object StateSourceOptions extends DataSourceOptions { case None => throw StateDataSourceErrors.committedBatchUnavailable(checkpointLocation) } - if (commitMetadata.stateUniqueIds.isDefined) { - Some(commitMetadata.stateUniqueIds.get(operatorId)) - } else { - None - } + commitMetadata.stateUniqueIds.flatMap(_.get(operatorId)) } // Modifies options due to external data. Returns modified options. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index c76d59290d63..ebef6e3dac55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -29,16 +29,6 @@ import org.apache.spark.sql.types.{NullType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{NextIterator, SerializableConfiguration} -/** - * Constants for store names used in Stream-Stream joins. - */ -object StatePartitionReaderStoreNames { - val LEFT_KEY_TO_NUM_VALUES_STORE = "left-keyToNumValues" - val LEFT_KEY_WITH_INDEX_TO_VALUE_STORE = "left-keyWithIndexToValue" - val RIGHT_KEY_TO_NUM_VALUES_STORE = "right-keyToNumValues" - val RIGHT_KEY_WITH_INDEX_TO_VALUE_STORE = "right-keyWithIndexToValue" -} - /** * An implementation of [[PartitionReaderFactory]] for State data source. This is used to support * general read from a state store instance, rather than specific to the operator. 
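Note: the store-name dispatch deleted here moves into SymmetricHashJoinStateManager.getStateStoreCheckpointId (next hunk and the diff below). For a stream-stream join without virtual column families, the per-partition checkpoint ID array always carries four entries in the order fixed by mergeStateStoreCheckpointInfo(). A standalone sketch of that layout; the object name and IDs are illustrative, while the store names match those used by this reader:

object JoinCkptIdLayoutSketch {
  def main(args: Array[String]): Unit = {
    // Index layout fixed by mergeStateStoreCheckpointInfo():
    //   0: left keyToNumValues    1: left keyWithIndexToValue
    //   2: right keyToNumValues   3: right keyWithIndexToValue
    val storeNames = Seq(
      "left-keyToNumValues",
      "left-keyWithIndexToValue",
      "right-keyToNumValues",
      "right-keyWithIndexToValue")
    val perPartitionIds = Seq("uuid-0", "uuid-1", "uuid-2", "uuid-3") // illustrative

    storeNames.zip(perPartitionIds).foreach { case (store, id) =>
      println(s"$store -> $id")
    }
  }
}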
@@ -107,28 +97,10 @@ abstract class StatePartitionReaderBase( } protected val getStoreUniqueId : Option[String] = { - val partitionStateUniqueIds = - partition.sourceOptions.operatorStateUniqueIds.map(_(partition.partition)) - if (partition.sourceOptions.storeName == StateStoreId.DEFAULT_STORE_NAME) { - partitionStateUniqueIds.map(_.head) - } else { - val stateStoreCheckpointIds = SymmetricHashJoinStateManager.getStateStoreCheckpointIds( - partition.partition, - partition.sourceOptions.operatorStateUniqueIds, - useColumnFamiliesForJoins = false) - - partition.sourceOptions.storeName match { - case StatePartitionReaderStoreNames.LEFT_KEY_TO_NUM_VALUES_STORE => - stateStoreCheckpointIds.left.keyToNumValues - case StatePartitionReaderStoreNames.LEFT_KEY_WITH_INDEX_TO_VALUE_STORE => - stateStoreCheckpointIds.left.valueToNumKeys - case StatePartitionReaderStoreNames.RIGHT_KEY_TO_NUM_VALUES_STORE => - stateStoreCheckpointIds.right.keyToNumValues - case StatePartitionReaderStoreNames.RIGHT_KEY_WITH_INDEX_TO_VALUE_STORE => - stateStoreCheckpointIds.right.valueToNumKeys - case _ => None - } - } + SymmetricHashJoinStateManager.getStateStoreCheckpointId( + storeName = partition.sourceOptions.storeName, + partitionId = partition.partition, + stateStoreCkptIds = partition.sourceOptions.operatorStateUniqueIds) } protected lazy val provider: StateStoreProvider = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala index 3068e5bc58af..9a7d4aa4a99c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala @@ -1162,7 +1162,8 @@ object SymmetricHashJoinStateManager { * mergeStateStoreCheckpointInfo(). This function is used to read it back into individual state * store checkpoint IDs. * @param partitionId - * @param stateInfo + * @param stateStoreCkptIds + * @param useColumnFamiliesForJoins * @return */ def getStateStoreCheckpointIds( @@ -1190,6 +1191,42 @@ object SymmetricHashJoinStateManager { } } + /** + * Stream-stream join has 4 state stores instead of one. So it will generate 4 different + * checkpoint IDs when not using virtual column families. + * This function is used to get the checkpoint ID for a specific state store + * by the name of the store, partition ID and the checkpoint IDs array. 
+ * @param storeName + * @param partitionId + * @param stateStoreCkptIds + * @param useColumnFamiliesForJoins + * @return + */ + def getStateStoreCheckpointId( + storeName: String, + partitionId: Int, + stateStoreCkptIds: Option[Array[Array[String]]], + useColumnFamiliesForJoins: Boolean = false) : Option[String] = { + if (useColumnFamiliesForJoins || storeName == StateStoreId.DEFAULT_STORE_NAME) { + stateStoreCkptIds.map(_(partitionId)).map(_.head) + } else { + val joinStateStoreCkptIds = getStateStoreCheckpointIds( + partitionId, stateStoreCkptIds, useColumnFamiliesForJoins) + + if (storeName == getStateStoreName(LeftSide, KeyToNumValuesType)) { + joinStateStoreCkptIds.left.keyToNumValues + } else if (storeName == getStateStoreName(RightSide, KeyToNumValuesType)) { + joinStateStoreCkptIds.right.keyToNumValues + } else if (storeName == getStateStoreName(LeftSide, KeyWithIndexToValueType)) { + joinStateStoreCkptIds.left.valueToNumKeys + } else if (storeName == getStateStoreName(RightSide, KeyWithIndexToValueType)) { + joinStateStoreCkptIds.right.valueToNumKeys + } else { + None + } + } + } + private[join] sealed trait StateStoreType private[join] case object KeyToNumValuesType extends StateStoreType { From a3d31616543cdfaac091da8b739feb65fcd371c9 Mon Sep 17 00:00:00 2001 From: Dylan Wong Date: Fri, 22 Aug 2025 23:44:10 +0000 Subject: [PATCH 06/11] PR fixes --- .../v2/state/StateDataSource.scala | 1 - ...StreamStreamJoinStatePartitionReader.scala | 4 ++-- .../join/StreamingSymmetricHashJoinExec.scala | 4 ++-- .../StreamingSymmetricHashJoinHelper.scala | 6 +++--- .../join/SymmetricHashJoinStateManager.scala | 20 +++++++++---------- 5 files changed, 17 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index 3d4f230c4443..828c06ab834a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -568,7 +568,6 @@ object StateSourceOptions extends DataSourceOptions { } } - val startBatchId = if (fromSnapshotOptions.isDefined) { fromSnapshotOptions.get.snapshotStartBatchId } else if (readChangeFeedOptions.isDefined) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala index 9795b244fc6c..0f8a3b3b609f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala @@ -88,9 +88,9 @@ class StreamStreamJoinStatePartitionReader( } private val keyWithIndexToValueStateStoreCkptId = if (joinSide == LeftSide) { - stateStoreCheckpointIds.left.valueToNumKeys + stateStoreCheckpointIds.left.keyWithIndexToValue } else { - stateStoreCheckpointIds.right.valueToNumKeys + stateStoreCheckpointIds.right.keyWithIndexToValue } /* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala index 28024598ac25..ef37185ce416 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala @@ -363,12 +363,12 @@ case class StreamingSymmetricHashJoinExec( new OneSideHashJoiner( LeftSide, left.output, leftKeys, leftInputIter, condition.leftSideOnly, postJoinFilter, stateWatermarkPredicates.left, partitionId, - checkpointIds.left.keyToNumValues, checkpointIds.left.valueToNumKeys, + checkpointIds.left.keyToNumValues, checkpointIds.left.keyWithIndexToValue, skippedNullValueCount, joinStateManagerStoreGenerator), new OneSideHashJoiner( RightSide, right.output, rightKeys, rightInputIter, condition.rightSideOnly, postJoinFilter, stateWatermarkPredicates.right, partitionId, - checkpointIds.right.keyToNumValues, checkpointIds.right.valueToNumKeys, + checkpointIds.right.keyToNumValues, checkpointIds.right.keyWithIndexToValue, skippedNullValueCount, joinStateManagerStoreGenerator)) // Join one side input using the other side's buffered/state rows. Here is how it is done. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinHelper.scala index 6f02a17efe34..7b02a43cd5a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinHelper.scala @@ -324,15 +324,15 @@ object StreamingSymmetricHashJoinHelper extends Logging { case class JoinerStateStoreCkptInfo( keyToNumValues: StateStoreCheckpointInfo, - valueToNumKeys: StateStoreCheckpointInfo) + keyWithIndexToValue: StateStoreCheckpointInfo) case class JoinStateStoreCkptInfo( left: JoinerStateStoreCkptInfo, right: JoinerStateStoreCkptInfo) case class JoinerStateStoreCheckpointId( - keyToNumValues: Option[String], - valueToNumKeys: Option[String]) + keyToNumValues: Option[String], + keyWithIndexToValue: Option[String]) case class JoinStateStoreCheckpointId( left: JoinerStateStoreCheckpointId, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala index 9a7d4aa4a99c..a0b64fcdf97e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala @@ -1135,17 +1135,17 @@ object SymmetricHashJoinStateManager { val ckptIds = joinCkptInfo.left.keyToNumValues.stateStoreCkptId.map( Array( _, - joinCkptInfo.left.valueToNumKeys.stateStoreCkptId.get, + joinCkptInfo.left.keyWithIndexToValue.stateStoreCkptId.get, joinCkptInfo.right.keyToNumValues.stateStoreCkptId.get, - joinCkptInfo.right.valueToNumKeys.stateStoreCkptId.get + joinCkptInfo.right.keyWithIndexToValue.stateStoreCkptId.get ) ) val baseCkptIds = 
joinCkptInfo.left.keyToNumValues.baseStateStoreCkptId.map( Array( _, - joinCkptInfo.left.valueToNumKeys.baseStateStoreCkptId.get, + joinCkptInfo.left.keyWithIndexToValue.baseStateStoreCkptId.get, joinCkptInfo.right.keyToNumValues.baseStateStoreCkptId.get, - joinCkptInfo.right.valueToNumKeys.baseStateStoreCkptId.get + joinCkptInfo.right.keyWithIndexToValue.baseStateStoreCkptId.get ) )
@@ -1173,8 +1173,8 @@ object SymmetricHashJoinStateManager { if (useColumnFamiliesForJoins) { val ckpt = stateStoreCkptIds.map(_(partitionId)).map(_.head) JoinStateStoreCheckpointId( - left = JoinerStateStoreCheckpointId(keyToNumValues = ckpt, valueToNumKeys = ckpt), - right = JoinerStateStoreCheckpointId(keyToNumValues = ckpt, valueToNumKeys = ckpt) + left = JoinerStateStoreCheckpointId(keyToNumValues = ckpt, keyWithIndexToValue = ckpt), + right = JoinerStateStoreCheckpointId(keyToNumValues = ckpt, keyWithIndexToValue = ckpt) ) } else { val stateStoreCkptIdsOpt = stateStoreCkptIds
@@ -1184,10 +1184,10 @@ JoinStateStoreCheckpointId( left = JoinerStateStoreCheckpointId( keyToNumValues = stateStoreCkptIdsOpt(0), - valueToNumKeys = stateStoreCkptIdsOpt(1)), + keyWithIndexToValue = stateStoreCkptIdsOpt(1)), right = JoinerStateStoreCheckpointId( keyToNumValues = stateStoreCkptIdsOpt(2), - valueToNumKeys = stateStoreCkptIdsOpt(3))) + keyWithIndexToValue = stateStoreCkptIdsOpt(3))) } }
@@ -1218,9 +1218,9 @@ } else if (storeName == getStateStoreName(RightSide, KeyToNumValuesType)) { joinStateStoreCkptIds.right.keyToNumValues } else if (storeName == getStateStoreName(LeftSide, KeyWithIndexToValueType)) { - joinStateStoreCkptIds.left.valueToNumKeys + joinStateStoreCkptIds.left.keyWithIndexToValue } else if (storeName == getStateStoreName(RightSide, KeyWithIndexToValueType)) { - joinStateStoreCkptIds.right.valueToNumKeys + joinStateStoreCkptIds.right.keyWithIndexToValue } else { None }
From c286e975a69723ac9d4a707963227a04a8075ca3 Mon Sep 17 00:00:00 2001
From: Dylan Wong
Date: Sat, 23 Aug 2025 16:42:56 +0000
Subject: [PATCH 07/11] Restore checkpoint ID guard in RocksDBStateStoreProvider
--- .../execution/streaming/state/RocksDBStateStoreProvider.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index 095d69f1be04..7098fd41f402 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -690,7 +690,7 @@ private[sql] class RocksDBStateStoreProvider rocksDB.load( version, - stateStoreCkptId = uniqueId, + stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None, readOnly = readOnly) // Create or reuse store instance
From 3d1dfb2ab62d9546dec8f5e46ba26f69aa969c2d Mon Sep 17 00:00:00 2001
From: Dylan Wong
Date: Tue, 26 Aug 2025 23:15:47 +0000
Subject: [PATCH 08/11] Fix method comments
--- .../stateful/join/SymmetricHashJoinStateManager.scala | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala
index
a0b64fcdf97e..a85e0ea2c445 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala
@@ -1158,9 +1158,10 @@ object SymmetricHashJoinStateManager { /** * Stream-stream join has 4 state stores instead of one. So it will generate 4 different - * checkpoint IDs. They are translated from each joiners' state store into an array through - * mergeStateStoreCheckpointInfo(). This function is used to read it back into individual state - * store checkpoint IDs. + * checkpoint IDs using stateStoreCkptIds. They are translated from each joiner's state + * store into an array through mergeStateStoreCheckpointInfo(). This function is used to read + * it back into individual state store checkpoint IDs for each store. + * If useColumnFamiliesForJoins is true, then it will always return the first checkpoint ID. * @param partitionId * @param stateStoreCkptIds * @param useColumnFamiliesForJoins
@@ -1195,7 +1196,9 @@ object SymmetricHashJoinStateManager { * Stream-stream join has 4 state stores instead of one. So it will generate 4 different * checkpoint IDs when not using virtual column families. * This function is used to get the checkpoint ID for a specific state store - * by the name of the store, partition ID and the checkpoint IDs array. + * by the name of the store, partition ID and the stateStoreCkptIds array. The expected names + * for the stores are generated by getStateStoreName(). + * If useColumnFamiliesForJoins is true, then it will always return the first checkpoint ID. * @param storeName * @param partitionId * @param stateStoreCkptIds
From 9662334ad7ba6b12c31072c2c5dcb9731ccab3cf Mon Sep 17 00:00:00 2001
From: Dylan Wong
Date: Wed, 27 Aug 2025 00:46:40 +0000
Subject: [PATCH 09/11] Add test
--- .../v2/state/StateDataSourceReadSuite.scala | 35 ++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
index 7be66ac2970a..8c66dd3d1461 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
@@ -23,7 +23,7 @@ import java.util.UUID import org.apache.hadoop.conf.Configuration import org.scalatest.Assertions -import org.apache.spark.{SparkException, SparkUnsupportedOperationException} +import org.apache.spark.{SparkException, SparkThrowable, SparkUnsupportedOperationException} import org.apache.spark.io.CompressionCodec import org.apache.spark.sql.{AnalysisException, DataFrame, Encoders, Row} import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericInternalRow}
@@ -632,6 +632,39 @@ class RocksDBWithCheckpointV2StateDataSourceReaderSuite extends StateDataSourceR "true") } + // TODO: Remove this test once we allow migrations from checkpoint v1 to v2 + test("reading checkpoint v2 store with version 1 should fail") { + withTempDir { tmpDir => + val inputData = MemoryStream[(Int, Long)] + val query = getStreamStreamJoinQuery(inputData) + testStream(query)( + StartStream(checkpointLocation = tmpDir.getCanonicalPath), + AddData(inputData, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5,
5L)), + ProcessAllAvailable(), + Execute { _ => Thread.sleep(2000) }, + StopStream + ) + + // Set the checkpoint version to 1 + spark.conf.set(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION, 1) + + // Verify reading state throws error when reading checkpoint v2 with version 1 + val exc = intercept[IllegalStateException] { + val stateDf = spark.read.format("statestore") + .option(StateSourceOptions.BATCH_ID, 0) + .option(StateSourceOptions.OPERATOR_ID, 0) + .load(tmpDir.getCanonicalPath) + stateDf.collect() + } + + checkError(exc.getCause.asInstanceOf[SparkThrowable], + "INVALID_LOG_VERSION.EXACT_MATCH_VERSION", "KD002", + Map( + "version" -> "2", + "matchVersion" -> "1")) + } + } + test("check unsupported modes with checkpoint v2") { withTempDir { tmpDir => val inputData = MemoryStream[(Int, Long)] From cd6aac31a07763b89c983da5f0824de48550bc5e Mon Sep 17 00:00:00 2001 From: Dylan Wong Date: Wed, 27 Aug 2025 01:10:43 +0000 Subject: [PATCH 10/11] Fix param comments --- .../join/SymmetricHashJoinStateManager.scala | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala index a85e0ea2c445..c0965747722e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala @@ -1162,10 +1162,12 @@ object SymmetricHashJoinStateManager { * store into an array through mergeStateStoreCheckpointInfo(). This function is used to read * it back into individual state store checkpoint IDs for each store. * If useColumnFamiliesForJoins is true, then it will always return the first checkpoint ID. - * @param partitionId - * @param stateStoreCkptIds - * @param useColumnFamiliesForJoins - * @return + * + * @param partitionId the partition ID of the state store + * @param stateStoreCkptIds the array of checkpoint IDs for all the state stores + * @param useColumnFamiliesForJoins whether virtual column families are used for the join + * + * @return the checkpoint IDs for all state stores used by this joiner */ def getStateStoreCheckpointIds( partitionId: Int, @@ -1199,11 +1201,13 @@ object SymmetricHashJoinStateManager { * by the name of the store, partition ID and the stateStoreCkptIds array. The expected names * for the stores are generated by getStateStoreName(). * If useColumnFamiliesForJoins is true, then it will always return the first checkpoint ID. 
- * @param storeName - * @param partitionId - * @param stateStoreCkptIds - * @param useColumnFamiliesForJoins - * @return + * + * @param storeName the name of the state store + * @param partitionId the partition ID of the state store + * @param stateStoreCkptIds the array of checkpoint IDs for all the state stores + * @param useColumnFamiliesForJoins whether virtual column families are used for the join + * + * @return the checkpoint ID for the specific state store, or None if not found */ def getStateStoreCheckpointId( storeName: String, From 6b3818187e1803afbb4bee9d9adc178485a07167 Mon Sep 17 00:00:00 2001 From: Dylan Wong Date: Wed, 27 Aug 2025 13:50:46 +0000 Subject: [PATCH 11/11] Fix test --- .../v2/state/StateDataSourceReadSuite.scala | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala index 8c66dd3d1461..d744304afb42 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala @@ -645,23 +645,22 @@ class RocksDBWithCheckpointV2StateDataSourceReaderSuite extends StateDataSourceR StopStream ) - // Set the checkpoint version to 1 - spark.conf.set(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION, 1) + withSQLConf(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "1") { + // Verify reading state throws error when reading checkpoint v2 with version 1 + val exc = intercept[IllegalStateException] { + val stateDf = spark.read.format("statestore") + .option(StateSourceOptions.BATCH_ID, 0) + .option(StateSourceOptions.OPERATOR_ID, 0) + .load(tmpDir.getCanonicalPath) + stateDf.collect() + } - // Verify reading state throws error when reading checkpoint v2 with version 1 - val exc = intercept[IllegalStateException] { - val stateDf = spark.read.format("statestore") - .option(StateSourceOptions.BATCH_ID, 0) - .option(StateSourceOptions.OPERATOR_ID, 0) - .load(tmpDir.getCanonicalPath) - stateDf.collect() + checkError(exc.getCause.asInstanceOf[SparkThrowable], + "INVALID_LOG_VERSION.EXACT_MATCH_VERSION", "KD002", + Map( + "version" -> "2", + "matchVersion" -> "1")) } - - checkError(exc.getCause.asInstanceOf[SparkThrowable], - "INVALID_LOG_VERSION.EXACT_MATCH_VERSION", "KD002", - Map( - "version" -> "2", - "matchVersion" -> "1")) } }