common/utils/src/main/resources/error/error-conditions.json: 19 additions & 0 deletions
@@ -368,6 +368,11 @@
"The change log writer version cannot be <version>."
]
},
"INVALID_CHECKPOINT_LINEAGE" : {
"message" : [
"Invalid checkpoint lineage: <lineage>. <message>"
]
},
"KEY_ROW_FORMAT_VALIDATION_FAILURE" : {
"message" : [
"<msg>"
@@ -5168,6 +5173,12 @@
],
"sqlState" : "42802"
},
"STATE_STORE_CHECKPOINT_IDS_NOT_SUPPORTED" : {
"message" : [
"<msg>"
],
"sqlState" : "KD002"
},
"STATE_STORE_CHECKPOINT_LOCATION_NOT_EMPTY" : {
"message" : [
"The checkpoint location <checkpointLocation> should be empty on batch 0",
@@ -5413,6 +5424,14 @@
},
"sqlState" : "42616"
},
"STDS_MIXED_CHECKPOINT_FORMAT_VERSIONS_NOT_SUPPORTED" : {
"message" : [
"Reading state across different checkpoint format versions is not supported.",
"startBatchId=<startBatchId>, endBatchId=<endBatchId>.",
"startFormatVersion=<startFormatVersion>, endFormatVersion=<endFormatVersion>."
],
"sqlState" : "KD002"
},
"STDS_NO_PARTITION_DISCOVERED_IN_STATE_STORE" : {
"message" : [
"The state does not have any partition. Please double check that the query points to the valid state. options: <sourceOptions>"
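For reference, with hypothetical batch IDs and format versions filled in, the new STDS_MIXED_CHECKPOINT_FORMAT_VERSIONS_NOT_SUPPORTED condition would render along the lines of:

Reading state across different checkpoint format versions is not supported. startBatchId=3, endBatchId=7. startFormatVersion=1, endFormatVersion=2.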
@@ -1916,6 +1916,30 @@ def conf(cls):
return cfg


class TransformWithStateInPandasWithCheckpointV2TestsMixin(TransformWithStateInPandasTestsMixin):
@classmethod
def conf(cls):
cfg = super().conf()
cfg.set("spark.sql.streaming.stateStore.checkpointFormatVersion", "2")
return cfg

# TODO(SPARK-53332): Add test back when checkpoint v2 support exists for snapshotStartBatchId
def test_transform_with_value_state_metadata(self):
pass


class TransformWithStateInPySparkWithCheckpointV2TestsMixin(TransformWithStateInPySparkTestsMixin):
@classmethod
def conf(cls):
cfg = super().conf()
cfg.set("spark.sql.streaming.stateStore.checkpointFormatVersion", "2")
return cfg

# TODO(SPARK-53332): Add test back when checkpoint v2 support exists for snapshotStartBatchId
def test_transform_with_value_state_metadata(self):
pass


class TransformWithStateInPandasTests(TransformWithStateInPandasTestsMixin, ReusedSQLTestCase):
pass

@@ -1924,6 +1948,18 @@ class TransformWithStateInPySparkTests(TransformWithStateInPySparkTestsMixin, Re
pass


class TransformWithStateInPandasWithCheckpointV2Tests(
TransformWithStateInPandasWithCheckpointV2TestsMixin, ReusedSQLTestCase
):
pass


class TransformWithStateInPySparkWithCheckpointV2Tests(
TransformWithStateInPySparkWithCheckpointV2TestsMixin, ReusedSQLTestCase
):
pass


if __name__ == "__main__":
from pyspark.sql.tests.pandas.test_pandas_transform_with_state import * # noqa: F401

@@ -2740,6 +2740,17 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
)
}

def invalidCheckpointLineage(lineage: String, message: String): Throwable = {
new SparkException(
errorClass = "CANNOT_LOAD_STATE_STORE.INVALID_CHECKPOINT_LINEAGE",
messageParameters = Map(
"lineage" -> lineage,
"message" -> message
),
cause = null
)
}

def notEnoughMemoryToLoadStore(
stateStoreId: String,
stateStoreProviderName: String,
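A hypothetical call site for the new helper; the variable, the lineage formatting, and the message text are invented for illustration, and only the helper's signature comes from the hunk above:

// Hypothetical usage: report a checkpoint lineage that cannot serve the requested version.
throw QueryExecutionErrors.invalidCheckpointLineage(
  lineageEntries.mkString(", "),   // lineageEntries is an assumed Seq[String]
  "Cannot find the requested version in the lineage.")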
@@ -371,7 +371,8 @@ case class StateSourceOptions(
stateVarName: Option[String],
readRegisteredTimers: Boolean,
flattenCollectionTypes: Boolean,
operatorStateUniqueIds: Option[Array[Array[String]]] = None) {
startOperatorStateUniqueIds: Option[Array[Array[String]]] = None,
endOperatorStateUniqueIds: Option[Array[Array[String]]] = None) {
def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE)

override def toString: String = {
@@ -576,29 +577,52 @@ object StateSourceOptions extends DataSourceOptions {
batchId.get
}

val operatorStateUniqueIds = getOperatorStateUniqueIds(
val endBatchId = if (readChangeFeedOptions.isDefined) {
readChangeFeedOptions.get.changeEndBatchId
} else {
batchId.get
}

val startOperatorStateUniqueIds = getOperatorStateUniqueIds(
sparkSession,
startBatchId,
operatorId,
resolvedCpLocation)

if (operatorStateUniqueIds.isDefined) {
val endOperatorStateUniqueIds = if (startBatchId == endBatchId) {
startOperatorStateUniqueIds
} else {
getOperatorStateUniqueIds(
sparkSession,
endBatchId,
operatorId,
resolvedCpLocation)
}

if (startOperatorStateUniqueIds.isDefined != endOperatorStateUniqueIds.isDefined) {
val startFormatVersion = if (startOperatorStateUniqueIds.isDefined) 2 else 1
val endFormatVersion = if (endOperatorStateUniqueIds.isDefined) 2 else 1
throw StateDataSourceErrors.mixedCheckpointFormatVersionsNotSupported(
startBatchId,
endBatchId,
startFormatVersion,
endFormatVersion
)
}

if (startOperatorStateUniqueIds.isDefined) {
if (fromSnapshotOptions.isDefined) {
throw StateDataSourceErrors.invalidOptionValue(
SNAPSHOT_START_BATCH_ID,
"Snapshot reading is currently not supported with checkpoint v2.")
}
if (readChangeFeedOptions.isDefined) {
throw StateDataSourceErrors.invalidOptionValue(
READ_CHANGE_FEED,
"Read change feed is currently not supported with checkpoint v2.")
}
}

StateSourceOptions(
resolvedCpLocation, batchId.get, operatorId, storeName, joinSide,
readChangeFeed, fromSnapshotOptions, readChangeFeedOptions,
stateVarName, readRegisteredTimers, flattenCollectionTypes, operatorStateUniqueIds)
stateVarName, readRegisteredTimers, flattenCollectionTypes,
startOperatorStateUniqueIds, endOperatorStateUniqueIds)
}

private def resolvedCheckpointLocation(
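For context, a minimal sketch of the two read modes this option parsing distinguishes; the format name, option keys, and checkpointDir are assumptions based on the state data source API rather than shown in this hunk. A plain batchId read resolves a single batch, so startBatchId equals endBatchId and the same unique IDs are reused; a readChangeFeed read resolves a distinct changeStartBatchId/changeEndBatchId pair, which is why both lineages are looked up. Note that the change above still rejects snapshotStartBatchId and readChangeFeed when checkpoint v2 unique IDs are present.

val checkpointDir = "/path/to/checkpoint"   // placeholder path

// Single-batch read: start and end batch are the same, so one set of unique IDs is reused.
val stateDf = spark.read
  .format("statestore")
  .option("batchId", "10")
  .load(checkpointDir)

// Change feed read: start and end batches differ, so both sets of unique IDs are resolved.
val changesDf = spark.read
  .format("statestore")
  .option("readChangeFeed", "true")
  .option("changeStartBatchId", "5")
  .option("changeEndBatchId", "10")
  .load(checkpointDir)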
@@ -81,6 +81,18 @@ object StateDataSourceErrors {
sourceOptions: StateSourceOptions): StateDataSourceException = {
new StateDataSourceNoPartitionDiscoveredInStateStore(sourceOptions)
}

def mixedCheckpointFormatVersionsNotSupported(
startBatchId: Long,
endBatchId: Long,
startFormatVersion: Int,
endFormatVersion: Int): StateDataSourceException = {
new StateDataSourceMixedCheckpointFormatVersionsNotSupported(
startBatchId,
endBatchId,
startFormatVersion,
endFormatVersion)
}
}

abstract class StateDataSourceException(
@@ -172,3 +184,18 @@ class StateDataSourceReadOperatorMetadataFailure(
"STDS_FAILED_TO_READ_OPERATOR_METADATA",
Map("checkpointLocation" -> checkpointLocation, "batchId" -> batchId.toString),
cause = null)

class StateDataSourceMixedCheckpointFormatVersionsNotSupported(
startBatchId: Long,
endBatchId: Long,
startFormatVersion: Int,
endFormatVersion: Int)
extends StateDataSourceException(
"STDS_MIXED_CHECKPOINT_FORMAT_VERSIONS_NOT_SUPPORTED",
Map(
"startBatchId" -> startBatchId.toString,
"endBatchId" -> endBatchId.toString,
"startFormatVersion" -> startFormatVersion.toString,
"endFormatVersion" -> endFormatVersion.toString
),
cause = null)
@@ -96,11 +96,20 @@ abstract class StatePartitionReaderBase(
schema, "value").asInstanceOf[StructType]
}

protected val getStoreUniqueId : Option[String] = {
protected def getStoreUniqueId(
operatorStateUniqueIds: Option[Array[Array[String]]]) : Option[String] = {
SymmetricHashJoinStateManager.getStateStoreCheckpointId(
storeName = partition.sourceOptions.storeName,
partitionId = partition.partition,
stateStoreCkptIds = partition.sourceOptions.operatorStateUniqueIds)
stateStoreCkptIds = operatorStateUniqueIds)
}

protected def getStartStoreUniqueId: Option[String] = {
getStoreUniqueId(partition.sourceOptions.startOperatorStateUniqueIds)
}

protected def getEndStoreUniqueId: Option[String] = {
getStoreUniqueId(partition.sourceOptions.endOperatorStateUniqueIds)
}

protected lazy val provider: StateStoreProvider = {
@@ -123,7 +132,7 @@
if (useColFamilies) {
val store = provider.getStore(
partition.sourceOptions.batchId + 1,
getStoreUniqueId)
getEndStoreUniqueId)
require(stateStoreColFamilySchemaOpt.isDefined)
val stateStoreColFamilySchema = stateStoreColFamilySchemaOpt.get
require(stateStoreColFamilySchema.keyStateEncoderSpec.isDefined)
@@ -182,9 +191,11 @@ class StatePartitionReader(
private lazy val store: ReadStateStore = {
partition.sourceOptions.fromSnapshotOptions match {
case None =>
assert(getStartStoreUniqueId == getEndStoreUniqueId,
"Start and end store unique IDs must be the same when not reading from snapshot")
provider.getReadStore(
partition.sourceOptions.batchId + 1,
getStoreUniqueId
getStartStoreUniqueId
)

case Some(fromSnapshotOptions) =>
@@ -261,7 +272,8 @@ class StateStoreChangeDataPartitionReader(
.getStateStoreChangeDataReader(
partition.sourceOptions.readChangeFeedOptions.get.changeStartBatchId + 1,
partition.sourceOptions.readChangeFeedOptions.get.changeEndBatchId + 1,
colFamilyNameOpt)
colFamilyNameOpt,
getEndStoreUniqueId)
}

override lazy val iter: Iterator[InternalRow] = {
@@ -76,21 +76,22 @@ class StreamStreamJoinStatePartitionReader(
partition.sourceOptions.stateCheckpointLocation.toString,
partition.sourceOptions.operatorId)

private val stateStoreCheckpointIds = SymmetricHashJoinStateManager.getStateStoreCheckpointIds(
private val startStateStoreCheckpointIds =
SymmetricHashJoinStateManager.getStateStoreCheckpointIds(
partition.partition,
partition.sourceOptions.operatorStateUniqueIds,
partition.sourceOptions.startOperatorStateUniqueIds,
usesVirtualColumnFamilies)

private val keyToNumValuesStateStoreCkptId = if (joinSide == LeftSide) {
stateStoreCheckpointIds.left.keyToNumValues
startStateStoreCheckpointIds.left.keyToNumValues
} else {
stateStoreCheckpointIds.right.keyToNumValues
startStateStoreCheckpointIds.right.keyToNumValues
}

private val keyWithIndexToValueStateStoreCkptId = if (joinSide == LeftSide) {
stateStoreCheckpointIds.left.keyWithIndexToValue
startStateStoreCheckpointIds.left.keyWithIndexToValue
} else {
stateStoreCheckpointIds.right.keyWithIndexToValue
startStateStoreCheckpointIds.right.keyWithIndexToValue
}

/*
@@ -389,7 +389,7 @@ case class TransformWithStateInPySparkExec(
store.abort()
}
}
setStoreMetrics(store)
setStoreMetrics(store, isStreaming)
setOperatorMetrics()
}).map { row =>
numOutputRows += 1
@@ -430,14 +430,14 @@ trait StateStoreWriter
* Set the SQL metrics related to the state store.
* This should be called in that task after the store has been updated.
*/
protected def setStoreMetrics(store: StateStore): Unit = {
protected def setStoreMetrics(store: StateStore, setCheckpointInfo: Boolean = true): Unit = {
Reviewer comment (Contributor):

Hmm, why do we need this change?

Author reply (Contributor):

In setStoreMetrics we call store.getStateStoreCheckpointInfo(). If we call this in the store.abort() case in TransformWithStateExec or TransformWithStateInPySparkExec, it will throw an exception, because the checkpoint info does not exist: we never committed. https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala#L343
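A condensed sketch of that caller pattern, simplified from the TransformWithStateExec and TransformWithStateInPySparkExec hunks in this PR (commit/abort handling abbreviated):

// Streaming runs commit the store, so getStateStoreCheckpointInfo() is populated.
// Batch runs abort the store, so checkpoint info was never created; passing
// isStreaming as setCheckpointInfo makes setStoreMetrics skip reading it.
if (isStreaming) {
  store.commit()
} else {
  store.abort()
}
setStoreMetrics(store, isStreaming)
setOperatorMetrics()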

val storeMetrics = store.metrics
longMetric("numTotalStateRows") += storeMetrics.numKeys
longMetric("stateMemory") += storeMetrics.memoryUsedBytes
setStoreCustomMetrics(storeMetrics.customMetrics)
setStoreInstanceMetrics(storeMetrics.instanceMetrics)

if (StatefulOperatorStateInfo.enableStateStoreCheckpointIds(conf)) {
if (StatefulOperatorStateInfo.enableStateStoreCheckpointIds(conf) && setCheckpointInfo) {
// Set the state store checkpoint information for the driver to collect
val ssInfo = store.getStateStoreCheckpointInfo()
setStateStoreCheckpointInfo(
@@ -346,7 +346,7 @@ case class TransformWithStateExec(
store.abort()
}
}
setStoreMetrics(store)
setStoreMetrics(store, isStreaming)
setOperatorMetrics()
closeStatefulProcessor()
statefulProcessor.setHandle(null)
@@ -29,7 +29,7 @@ import scala.util.control.NonFatal
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs._

import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext}
import org.apache.spark.{SparkConf, SparkEnv, TaskContext}
import org.apache.spark.internal.{Logging, LogKeys}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
@@ -292,9 +292,9 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
/** Get the state store for making updates to create a new `version` of the store. */
override def getStore(version: Long, uniqueId: Option[String] = None): StateStore = {
if (uniqueId.isDefined) {
throw QueryExecutionErrors.cannotLoadStore(new SparkException(
throw StateStoreErrors.stateStoreCheckpointIdsNotSupported(
"HDFSBackedStateStoreProvider does not support checkpointFormatVersion > 1 " +
"but a state store checkpointID is passed in"))
"but a state store checkpointID is passed in")
}
val newMap = getLoadedMapForStore(version)
logInfo(log"Retrieved version ${MDC(LogKeys.STATE_STORE_VERSION, version)} " +
@@ -369,10 +369,10 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
hadoopConf: Configuration,
useMultipleValuesPerKey: Boolean = false,
stateSchemaProvider: Option[StateSchemaProvider] = None): Unit = {
assert(
!storeConf.enableStateStoreCheckpointIds,
"HDFS State Store Provider doesn't support checkpointFormatVersion >= 2 " +
s"checkpointFormatVersion ${storeConf.stateStoreCheckpointFormatVersion}")
if (storeConf.enableStateStoreCheckpointIds) {
throw StateStoreErrors.stateStoreCheckpointIdsNotSupported(
"HDFSBackedStateStoreProvider does not support checkpointFormatVersion > 1")
}

this.stateStoreId_ = stateStoreId
this.keySchema = keySchema
@@ -1064,8 +1064,16 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
override def getStateStoreChangeDataReader(
startVersion: Long,
endVersion: Long,
colFamilyNameOpt: Option[String] = None):
colFamilyNameOpt: Option[String] = None,
endVersionStateStoreCkptId: Option[String] = None):
StateStoreChangeDataReader = {

if (endVersionStateStoreCkptId.isDefined) {
throw StateStoreErrors.stateStoreCheckpointIdsNotSupported(
"HDFSBackedStateStoreProvider does not support checkpointFormatVersion > 1 " +
"but a state store checkpointID is passed in")
}

// Multiple column families are not supported with HDFSBackedStateStoreProvider
if (colFamilyNameOpt.isDefined) {
throw StateStoreErrors.multipleColumnFamiliesNotSupported(providerName)
@@ -1099,7 +1107,7 @@ class HDFSBackedStateStoreChangeDataReader(
extends StateStoreChangeDataReader(
fm, stateLocation, startVersion, endVersion, compressionCodec) {

override protected var changelogSuffix: String = "delta"
override protected val changelogSuffix: String = "delta"

override def getNext(): (RecordType.Value, UnsafeRow, UnsafeRow, Long) = {
val reader = currentChangelogReader()
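A minimal sketch of the configuration combination that now raises STATE_STORE_CHECKPOINT_IDS_NOT_SUPPORTED instead of tripping a bare assertion; the provider class config key is assumed to be the standard one, and only checkpointFormatVersion appears verbatim in this PR:

// HDFSBackedStateStoreProvider only supports checkpoint format version 1, so
// combining it with format version 2 now fails with a proper error condition.
spark.conf.set(
  "spark.sql.streaming.stateStore.providerClass",
  "org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider")
spark.conf.set("spark.sql.streaming.stateStore.checkpointFormatVersion", "2")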