@@ -370,7 +370,8 @@ case class StateSourceOptions(
readChangeFeedOptions: Option[ReadChangeFeedOptions],
stateVarName: Option[String],
readRegisteredTimers: Boolean,
flattenCollectionTypes: Boolean) {
flattenCollectionTypes: Boolean,
operatorStateUniqueIds: Option[Array[Array[String]]] = None) {
def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE)

override def toString: String = {
@@ -567,10 +568,37 @@ object StateSourceOptions extends DataSourceOptions {
}
}

val startBatchId = if (fromSnapshotOptions.isDefined) {
fromSnapshotOptions.get.snapshotStartBatchId
} else if (readChangeFeedOptions.isDefined) {
readChangeFeedOptions.get.changeStartBatchId
} else {
batchId.get
}

val operatorStateUniqueIds = getOperatorStateUniqueIds(
sparkSession,
startBatchId,
operatorId,
resolvedCpLocation)

if (operatorStateUniqueIds.isDefined) {
if (fromSnapshotOptions.isDefined) {
throw StateDataSourceErrors.invalidOptionValue(
SNAPSHOT_START_BATCH_ID,
"Snapshot reading is currently not supported with checkpoint v2.")
}
if (readChangeFeedOptions.isDefined) {
throw StateDataSourceErrors.invalidOptionValue(
READ_CHANGE_FEED,
"Read change feed is currently not supported with checkpoint v2.")
}
}

StateSourceOptions(
resolvedCpLocation, batchId.get, operatorId, storeName, joinSide,
readChangeFeed, fromSnapshotOptions, readChangeFeedOptions,
stateVarName, readRegisteredTimers, flattenCollectionTypes)
stateVarName, readRegisteredTimers, flattenCollectionTypes, operatorStateUniqueIds)
}

private def resolvedCheckpointLocation(
@@ -589,6 +617,20 @@ object StateSourceOptions extends DataSourceOptions {
}
}

private def getOperatorStateUniqueIds(
session: SparkSession,
batchId: Long,
operatorId: Long,
checkpointLocation: String): Option[Array[Array[String]]] = {
val commitLog = new StreamingQueryCheckpointMetadata(session, checkpointLocation).commitLog
val commitMetadata = commitLog.get(batchId) match {
case Some(commitMetadata) => commitMetadata
case None => throw StateDataSourceErrors.committedBatchUnavailable(checkpointLocation)
}

commitMetadata.stateUniqueIds.flatMap(_.get(operatorId))
}
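
For orientation, a minimal standalone sketch of the layout these IDs follow (the values and scaffolding are hypothetical; the shapes match the accessors above, and the per-store order is the one produced by mergeStateStoreCheckpointInfo() further down in this diff):

```scala
object UniqueIdLayoutSketch extends App {
  // Commit-log shape: operatorId -> (partition -> one checkpoint ID per store).
  val stateUniqueIds: Map[Long, Array[Array[String]]] = Map(
    0L -> Array(
      // Partition 0: the four stores of a stream-stream join on state format v2,
      // in merge order: left.keyToNumValues, left.keyWithIndexToValue,
      // right.keyToNumValues, right.keyWithIndexToValue.
      Array("id-l-k2n-p0", "id-l-kwi-p0", "id-r-k2n-p0", "id-r-kwi-p0"),
      // Partition 1, same layout.
      Array("id-l-k2n-p1", "id-l-kwi-p1", "id-r-k2n-p1", "id-r-kwi-p1")
    )
  )

  // What getOperatorStateUniqueIds() hands to StateSourceOptions for operator 0.
  val operatorStateUniqueIds: Option[Array[Array[String]]] = stateUniqueIds.get(0L)

  // Right side's keyToNumValues for partition 1 sits at index 2 of the layout.
  println(operatorStateUniqueIds.map(_(1)).map(_(2))) // Some(id-r-k2n-p1)
}
```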

// Modifies options due to external data. Returns modified options.
// If this is a join operator specifying a store name using state format v3,
// we need to modify the options.
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager
import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType, TransformWithStateVariableInfo}
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.execution.streaming.state.RecordType.{getRecordTypeAsString, RecordType}
@@ -95,6 +96,13 @@ abstract class StatePartitionReaderBase(
schema, "value").asInstanceOf[StructType]
}

protected val getStoreUniqueId: Option[String] = {
SymmetricHashJoinStateManager.getStateStoreCheckpointId(
storeName = partition.sourceOptions.storeName,
partitionId = partition.partition,
stateStoreCkptIds = partition.sourceOptions.operatorStateUniqueIds)
}

protected lazy val provider: StateStoreProvider = {
val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName)
@@ -113,7 +121,9 @@ abstract class StatePartitionReaderBase(
val isInternal = partition.sourceOptions.readRegisteredTimers

if (useColFamilies) {
val store = provider.getStore(partition.sourceOptions.batchId + 1)
val store = provider.getStore(
partition.sourceOptions.batchId + 1,
getStoreUniqueId)
require(stateStoreColFamilySchemaOpt.isDefined)
val stateStoreColFamilySchema = stateStoreColFamilySchemaOpt.get
require(stateStoreColFamilySchema.keyStateEncoderSpec.isDefined)
@@ -171,7 +181,11 @@ class StatePartitionReader(

private lazy val store: ReadStateStore = {
partition.sourceOptions.fromSnapshotOptions match {
case None => provider.getReadStore(partition.sourceOptions.batchId + 1)
case None =>
provider.getReadStore(
partition.sourceOptions.batchId + 1,
getStoreUniqueId
)

case Some(fromSnapshotOptions) =>
if (!provider.isInstanceOf[SupportsFineGrainedReplay]) {
@@ -71,6 +71,28 @@ class StreamStreamJoinStatePartitionReader(
throw StateDataSourceErrors.internalError("Unexpected join side for stream-stream read!")
}

private val usesVirtualColumnFamilies = StreamStreamJoinStateHelper.usesVirtualColumnFamilies(
hadoopConf.value,
partition.sourceOptions.stateCheckpointLocation.toString,
partition.sourceOptions.operatorId)

private val stateStoreCheckpointIds = SymmetricHashJoinStateManager.getStateStoreCheckpointIds(
partition.partition,
partition.sourceOptions.operatorStateUniqueIds,
usesVirtualColumnFamilies)

private val keyToNumValuesStateStoreCkptId = if (joinSide == LeftSide) {
stateStoreCheckpointIds.left.keyToNumValues
} else {
stateStoreCheckpointIds.right.keyToNumValues
}

private val keyWithIndexToValueStateStoreCkptId = if (joinSide == LeftSide) {
stateStoreCheckpointIds.left.keyWithIndexToValue
} else {
stateStoreCheckpointIds.right.keyWithIndexToValue
}

/*
* This is to handle the difference of schema across state format versions. The major difference
* is whether we have added new field(s) in addition to the fields from input schema.
@@ -85,10 +107,7 @@
// column from the value schema to get the actual fields.
if (maybeMatchedColumn.name == "matched" && maybeMatchedColumn.dataType == BooleanType) {
// If checkpoint is using one store and virtual column families, version is 3
if (StreamStreamJoinStateHelper.usesVirtualColumnFamilies(
hadoopConf.value,
partition.sourceOptions.stateCheckpointLocation.toString,
partition.sourceOptions.operatorId)) {
if (usesVirtualColumnFamilies) {
(valueSchema.dropRight(1), 3)
} else {
(valueSchema.dropRight(1), 2)
@@ -130,8 +149,8 @@
storeConf = storeConf,
hadoopConf = hadoopConf.value,
partitionId = partition.partition,
keyToNumValuesStateStoreCkptId = None,
keyWithIndexToValueStateStoreCkptId = None,
keyToNumValuesStateStoreCkptId = keyToNumValuesStateStoreCkptId,
keyWithIndexToValueStateStoreCkptId = keyWithIndexToValueStateStoreCkptId,
formatVersion,
skippedNullValueCount = None,
useStateStoreCoordinator = false,
@@ -351,7 +351,7 @@ case class StreamingSymmetricHashJoinExec(

assert(stateInfo.isDefined, "State info not defined")
val checkpointIds = SymmetricHashJoinStateManager.getStateStoreCheckpointIds(
partitionId, stateInfo.get, useVirtualColumnFamilies)
partitionId, stateInfo.get.stateStoreCkptIds, useVirtualColumnFamilies)

val inputSchema = left.output ++ right.output
val postJoinFilter =
@@ -363,12 +363,12 @@
new OneSideHashJoiner(
LeftSide, left.output, leftKeys, leftInputIter,
condition.leftSideOnly, postJoinFilter, stateWatermarkPredicates.left, partitionId,
checkpointIds.left.keyToNumValues, checkpointIds.left.valueToNumKeys,
checkpointIds.left.keyToNumValues, checkpointIds.left.keyWithIndexToValue,
skippedNullValueCount, joinStateManagerStoreGenerator),
new OneSideHashJoiner(
RightSide, right.output, rightKeys, rightInputIter,
condition.rightSideOnly, postJoinFilter, stateWatermarkPredicates.right, partitionId,
checkpointIds.right.keyToNumValues, checkpointIds.right.valueToNumKeys,
checkpointIds.right.keyToNumValues, checkpointIds.right.keyWithIndexToValue,
skippedNullValueCount, joinStateManagerStoreGenerator))

// Join one side input using the other side's buffered/state rows. Here is how it is done.
@@ -324,15 +324,15 @@ object StreamingSymmetricHashJoinHelper extends Logging {

case class JoinerStateStoreCkptInfo(
keyToNumValues: StateStoreCheckpointInfo,
valueToNumKeys: StateStoreCheckpointInfo)
keyWithIndexToValue: StateStoreCheckpointInfo)

case class JoinStateStoreCkptInfo(
left: JoinerStateStoreCkptInfo,
right: JoinerStateStoreCkptInfo)

case class JoinerStateStoreCheckpointId(
keyToNumValues: Option[String],
valueToNumKeys: Option[String])
keyToNumValues: Option[String],
keyWithIndexToValue: Option[String])

case class JoinStateStoreCheckpointId(
left: JoinerStateStoreCheckpointId,
@@ -1135,17 +1135,17 @@ object SymmetricHashJoinStateManager {
val ckptIds = joinCkptInfo.left.keyToNumValues.stateStoreCkptId.map(
Array(
_,
joinCkptInfo.left.valueToNumKeys.stateStoreCkptId.get,
joinCkptInfo.left.keyWithIndexToValue.stateStoreCkptId.get,
joinCkptInfo.right.keyToNumValues.stateStoreCkptId.get,
joinCkptInfo.right.valueToNumKeys.stateStoreCkptId.get
joinCkptInfo.right.keyWithIndexToValue.stateStoreCkptId.get
)
)
val baseCkptIds = joinCkptInfo.left.keyToNumValues.baseStateStoreCkptId.map(
Array(
_,
joinCkptInfo.left.valueToNumKeys.baseStateStoreCkptId.get,
joinCkptInfo.left.keyWithIndexToValue.baseStateStoreCkptId.get,
joinCkptInfo.right.keyToNumValues.baseStateStoreCkptId.get,
joinCkptInfo.right.valueToNumKeys.baseStateStoreCkptId.get
joinCkptInfo.right.keyWithIndexToValue.baseStateStoreCkptId.get
)
)

@@ -1158,35 +1158,79 @@

/**
* Stream-stream join has 4 state stores instead of one. So it will generate 4 different
* checkpoint IDs. They are translated from each joiners' state store into an array through
* mergeStateStoreCheckpointInfo(). This function is used to read it back into individual state
* store checkpoint IDs.
* @param partitionId
* @param stateInfo
* @return
* checkpoint IDs, carried in stateStoreCkptIds. They are translated from each joiner's state
* store into an array through mergeStateStoreCheckpointInfo(). This function reads them
* back into individual state store checkpoint IDs for each store.
* If useColumnFamiliesForJoins is true, the single shared checkpoint ID (the first entry)
* is returned for all four stores.
*
* @param partitionId the partition ID of the state store
* @param stateStoreCkptIds the array of checkpoint IDs for all the state stores
* @param useColumnFamiliesForJoins whether virtual column families are used for the join
*
* @return the checkpoint IDs for all four state stores used by this join
*/
def getStateStoreCheckpointIds(
partitionId: Int,
stateInfo: StatefulOperatorStateInfo,
stateStoreCkptIds: Option[Array[Array[String]]],
useColumnFamiliesForJoins: Boolean): JoinStateStoreCheckpointId = {
if (useColumnFamiliesForJoins) {
val ckpt = stateInfo.stateStoreCkptIds.map(_(partitionId)).map(_.head)
val ckpt = stateStoreCkptIds.map(_(partitionId)).map(_.head)
JoinStateStoreCheckpointId(
left = JoinerStateStoreCheckpointId(keyToNumValues = ckpt, valueToNumKeys = ckpt),
right = JoinerStateStoreCheckpointId(keyToNumValues = ckpt, valueToNumKeys = ckpt)
left = JoinerStateStoreCheckpointId(keyToNumValues = ckpt, keyWithIndexToValue = ckpt),
right = JoinerStateStoreCheckpointId(keyToNumValues = ckpt, keyWithIndexToValue = ckpt)
)
} else {
val stateStoreCkptIds = stateInfo.stateStoreCkptIds
val stateStoreCkptIdsOpt = stateStoreCkptIds
.map(_(partitionId))
.map(_.map(Option(_)))
.getOrElse(Array.fill[Option[String]](4)(None))
JoinStateStoreCheckpointId(
left = JoinerStateStoreCheckpointId(
keyToNumValues = stateStoreCkptIds(0),
valueToNumKeys = stateStoreCkptIds(1)),
keyToNumValues = stateStoreCkptIdsOpt(0),
keyWithIndexToValue = stateStoreCkptIdsOpt(1)),
right = JoinerStateStoreCheckpointId(
keyToNumValues = stateStoreCkptIds(2),
valueToNumKeys = stateStoreCkptIds(3)))
keyToNumValues = stateStoreCkptIdsOpt(2),
keyWithIndexToValue = stateStoreCkptIdsOpt(3)))
}
}
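
A brief usage sketch of the two branches, written against the signature above (checkpoint IDs are hypothetical, and the calls assume SymmetricHashJoinStateManager is in scope):

```scala
// Hypothetical per-partition checkpoint IDs: four stores under state format
// v2, a single store when the join uses virtual column families.
val v2Ids: Option[Array[Array[String]]] =
  Some(Array(Array("k2n-L", "kwi-L", "k2n-R", "kwi-R")))
val vcfIds: Option[Array[Array[String]]] =
  Some(Array(Array("single-store-id")))

val v2 = SymmetricHashJoinStateManager.getStateStoreCheckpointIds(
  partitionId = 0, stateStoreCkptIds = v2Ids, useColumnFamiliesForJoins = false)
// v2.left.keyToNumValues == Some("k2n-L")
// v2.right.keyWithIndexToValue == Some("kwi-R")

val vcf = SymmetricHashJoinStateManager.getStateStoreCheckpointIds(
  partitionId = 0, stateStoreCkptIds = vcfIds, useColumnFamiliesForJoins = true)
// All four fields of vcf carry Some("single-store-id").
```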

/**
* Stream-stream join has 4 state stores instead of one. So it will generate 4 different
* checkpoint IDs when not using virtual column families.
* This function looks up the checkpoint ID for a specific state store by store name,
* partition ID, and the stateStoreCkptIds array. The expected store names are those
* generated by getStateStoreName().
* If useColumnFamiliesForJoins is true, the single shared checkpoint ID (the first entry)
* is returned regardless of the store name.
*
* @param storeName the name of the state store
* @param partitionId the partition ID of the state store
* @param stateStoreCkptIds the array of checkpoint IDs for all the state stores
* @param useColumnFamiliesForJoins whether virtual column families are used for the join
*
* @return the checkpoint ID for the specific state store, or None if not found
*/
def getStateStoreCheckpointId(
storeName: String,
partitionId: Int,
stateStoreCkptIds: Option[Array[Array[String]]],
useColumnFamiliesForJoins: Boolean = false): Option[String] = {
if (useColumnFamiliesForJoins || storeName == StateStoreId.DEFAULT_STORE_NAME) {
stateStoreCkptIds.map(_(partitionId)).map(_.head)
} else {
val joinStateStoreCkptIds = getStateStoreCheckpointIds(
partitionId, stateStoreCkptIds, useColumnFamiliesForJoins)

if (storeName == getStateStoreName(LeftSide, KeyToNumValuesType)) {
joinStateStoreCkptIds.left.keyToNumValues
} else if (storeName == getStateStoreName(RightSide, KeyToNumValuesType)) {
joinStateStoreCkptIds.right.keyToNumValues
} else if (storeName == getStateStoreName(LeftSide, KeyWithIndexToValueType)) {
joinStateStoreCkptIds.left.keyWithIndexToValue
} else if (storeName == getStateStoreName(RightSide, KeyWithIndexToValueType)) {
joinStateStoreCkptIds.right.keyWithIndexToValue
} else {
None
}
}
}
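
And a sketch of the name-based lookup (hypothetical IDs; assumes SymmetricHashJoinStateManager plus the join-side and store-type values referenced above are imported):

```scala
// Hypothetical checkpoint IDs for partition 0, laid out in
// mergeStateStoreCheckpointInfo() order.
val ids: Option[Array[Array[String]]] =
  Some(Array(Array("k2n-L", "kwi-L", "k2n-R", "kwi-R")))

// Resolve the right side's keyWithIndexToValue store by its generated name.
val name = SymmetricHashJoinStateManager.getStateStoreName(
  RightSide, KeyWithIndexToValueType)
val ckptId = SymmetricHashJoinStateManager.getStateStoreCheckpointId(
  storeName = name, partitionId = 0, stateStoreCkptIds = ids)
// ckptId == Some("kwi-R"); a name that matches no join store yields None.
```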
