diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 06f8d3a78252..efb054ef95d7 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -3835,7 +3835,7 @@ "STATE_STORE_PROVIDER_DOES_NOT_SUPPORT_FINE_GRAINED_STATE_REPLAY" : { "message" : [ "The given State Store Provider does not extend org.apache.spark.sql.execution.streaming.state.SupportsFineGrainedReplay.", - "Therefore, it does not support option snapshotStartBatchId in state data source." + "Therefore, it does not support option snapshotStartBatchId or readChangeFeed in state data source." ], "sqlState" : "42K06" }, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index 213573a756bc..e2c5499fe439 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DI import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide} import org.apache.spark.sql.execution.streaming.state.{StateSchemaCompatibilityChecker, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId} import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration @@ -94,10 +94,20 @@ class StateDataSource extends TableProvider with DataSourceRegister { manager.readSchemaFile() } - new StructType() - .add("key", keySchema) - .add("value", valueSchema) - .add("partition_id", IntegerType) + if (sourceOptions.readChangeFeed) { + new StructType() + .add("batch_id", LongType) + .add("change_type", StringType) + .add("key", keySchema) + .add("value", valueSchema) + .add("partition_id", IntegerType) + } else { + new StructType() + .add("key", keySchema) + .add("value", valueSchema) + .add("partition_id", IntegerType) + } + } catch { case NonFatal(e) => throw StateDataSourceErrors.failedToReadStateSchema(sourceOptions, e) @@ -125,21 +135,38 @@ class StateDataSource extends TableProvider with DataSourceRegister { override def supportsExternalMetadata(): Boolean = false } +case class FromSnapshotOptions( + snapshotStartBatchId: Long, + snapshotPartitionId: Int) + +case class ReadChangeFeedOptions( + changeStartBatchId: Long, + changeEndBatchId: Long +) + case class StateSourceOptions( resolvedCpLocation: String, batchId: Long, operatorId: Int, storeName: String, joinSide: JoinSideValues, - snapshotStartBatchId: Option[Long], - snapshotPartitionId: Option[Int]) { + readChangeFeed: Boolean, + fromSnapshotOptions: Option[FromSnapshotOptions], + readChangeFeedOptions: Option[ReadChangeFeedOptions]) { def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE) override def toString: String = { - s"StateSourceOptions(checkpointLocation=$resolvedCpLocation, batchId=$batchId, " + - s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide, " + - 
s"snapshotStartBatchId=${snapshotStartBatchId.getOrElse("None")}, " + - s"snapshotPartitionId=${snapshotPartitionId.getOrElse("None")})" + var desc = s"StateSourceOptions(checkpointLocation=$resolvedCpLocation, batchId=$batchId, " + + s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide" + if (fromSnapshotOptions.isDefined) { + desc += s", snapshotStartBatchId=${fromSnapshotOptions.get.snapshotStartBatchId}" + desc += s", snapshotPartitionId=${fromSnapshotOptions.get.snapshotPartitionId}" + } + if (readChangeFeedOptions.isDefined) { + desc += s", changeStartBatchId=${readChangeFeedOptions.get.changeStartBatchId}" + desc += s", changeEndBatchId=${readChangeFeedOptions.get.changeEndBatchId}" + } + desc + ")" } } @@ -151,6 +178,9 @@ object StateSourceOptions extends DataSourceOptions { val JOIN_SIDE = newOption("joinSide") val SNAPSHOT_START_BATCH_ID = newOption("snapshotStartBatchId") val SNAPSHOT_PARTITION_ID = newOption("snapshotPartitionId") + val READ_CHANGE_FEED = newOption("readChangeFeed") + val CHANGE_START_BATCH_ID = newOption("changeStartBatchId") + val CHANGE_END_BATCH_ID = newOption("changeEndBatchId") object JoinSideValues extends Enumeration { type JoinSideValues = Value @@ -172,16 +202,6 @@ object StateSourceOptions extends DataSourceOptions { throw StateDataSourceErrors.requiredOptionUnspecified(PATH) }.get - val resolvedCpLocation = resolvedCheckpointLocation(hadoopConf, checkpointLocation) - - val batchId = Option(options.get(BATCH_ID)).map(_.toLong).orElse { - Some(getLastCommittedBatch(sparkSession, resolvedCpLocation)) - }.get - - if (batchId < 0) { - throw StateDataSourceErrors.invalidOptionValueIsNegative(BATCH_ID) - } - val operatorId = Option(options.get(OPERATOR_ID)).map(_.toInt) .orElse(Some(0)).get @@ -210,30 +230,97 @@ object StateSourceOptions extends DataSourceOptions { throw StateDataSourceErrors.conflictOptions(Seq(JOIN_SIDE, STORE_NAME)) } - val snapshotStartBatchId = Option(options.get(SNAPSHOT_START_BATCH_ID)).map(_.toLong) - if (snapshotStartBatchId.exists(_ < 0)) { - throw StateDataSourceErrors.invalidOptionValueIsNegative(SNAPSHOT_START_BATCH_ID) - } else if (snapshotStartBatchId.exists(_ > batchId)) { - throw StateDataSourceErrors.invalidOptionValue( - SNAPSHOT_START_BATCH_ID, s"value should be less than or equal to $batchId") - } + val resolvedCpLocation = resolvedCheckpointLocation(hadoopConf, checkpointLocation) + + var batchId = Option(options.get(BATCH_ID)).map(_.toLong) + val snapshotStartBatchId = Option(options.get(SNAPSHOT_START_BATCH_ID)).map(_.toLong) val snapshotPartitionId = Option(options.get(SNAPSHOT_PARTITION_ID)).map(_.toInt) - if (snapshotPartitionId.exists(_ < 0)) { - throw StateDataSourceErrors.invalidOptionValueIsNegative(SNAPSHOT_PARTITION_ID) - } - // both snapshotPartitionId and snapshotStartBatchId are required at the same time, because - // each partition may have different checkpoint status - if (snapshotPartitionId.isDefined && snapshotStartBatchId.isEmpty) { - throw StateDataSourceErrors.requiredOptionUnspecified(SNAPSHOT_START_BATCH_ID) - } else if (snapshotPartitionId.isEmpty && snapshotStartBatchId.isDefined) { - throw StateDataSourceErrors.requiredOptionUnspecified(SNAPSHOT_PARTITION_ID) + val readChangeFeed = Option(options.get(READ_CHANGE_FEED)).exists(_.toBoolean) + + val changeStartBatchId = Option(options.get(CHANGE_START_BATCH_ID)).map(_.toLong) + var changeEndBatchId = Option(options.get(CHANGE_END_BATCH_ID)).map(_.toLong) + + var fromSnapshotOptions: Option[FromSnapshotOptions] = None + var 
readChangeFeedOptions: Option[ReadChangeFeedOptions] = None + + if (readChangeFeed) { + if (joinSide != JoinSideValues.none) { + throw StateDataSourceErrors.conflictOptions(Seq(JOIN_SIDE, READ_CHANGE_FEED)) + } + if (batchId.isDefined) { + throw StateDataSourceErrors.conflictOptions(Seq(BATCH_ID, READ_CHANGE_FEED)) + } + if (snapshotStartBatchId.isDefined) { + throw StateDataSourceErrors.conflictOptions(Seq(SNAPSHOT_START_BATCH_ID, READ_CHANGE_FEED)) + } + if (snapshotPartitionId.isDefined) { + throw StateDataSourceErrors.conflictOptions(Seq(SNAPSHOT_PARTITION_ID, READ_CHANGE_FEED)) + } + + if (changeStartBatchId.isEmpty) { + throw StateDataSourceErrors.requiredOptionUnspecified(CHANGE_START_BATCH_ID) + } + changeEndBatchId = Some( + changeEndBatchId.getOrElse(getLastCommittedBatch(sparkSession, resolvedCpLocation))) + + // changeStartBatchId and changeEndBatchId must all be defined at this point + if (changeStartBatchId.get < 0) { + throw StateDataSourceErrors.invalidOptionValueIsNegative(CHANGE_START_BATCH_ID) + } + if (changeEndBatchId.get < changeStartBatchId.get) { + throw StateDataSourceErrors.invalidOptionValue(CHANGE_END_BATCH_ID, + s"$CHANGE_END_BATCH_ID cannot be smaller than $CHANGE_START_BATCH_ID. " + + s"Please check the input to $CHANGE_END_BATCH_ID, or if you are using its default " + + s"value, make sure that $CHANGE_START_BATCH_ID is less than ${changeEndBatchId.get}.") + } + + batchId = Some(changeEndBatchId.get) + + readChangeFeedOptions = Option( + ReadChangeFeedOptions(changeStartBatchId.get, changeEndBatchId.get)) + } else { + if (changeStartBatchId.isDefined) { + throw StateDataSourceErrors.invalidOptionValue(CHANGE_START_BATCH_ID, + s"Only specify this option when $READ_CHANGE_FEED is set to true.") + } + if (changeEndBatchId.isDefined) { + throw StateDataSourceErrors.invalidOptionValue(CHANGE_END_BATCH_ID, + s"Only specify this option when $READ_CHANGE_FEED is set to true.") + } + + batchId = Some(batchId.getOrElse(getLastCommittedBatch(sparkSession, resolvedCpLocation))) + + if (batchId.get < 0) { + throw StateDataSourceErrors.invalidOptionValueIsNegative(BATCH_ID) + } + if (snapshotStartBatchId.exists(_ < 0)) { + throw StateDataSourceErrors.invalidOptionValueIsNegative(SNAPSHOT_START_BATCH_ID) + } else if (snapshotStartBatchId.exists(_ > batchId.get)) { + throw StateDataSourceErrors.invalidOptionValue( + SNAPSHOT_START_BATCH_ID, s"value should be less than or equal to ${batchId.get}") + } + if (snapshotPartitionId.exists(_ < 0)) { + throw StateDataSourceErrors.invalidOptionValueIsNegative(SNAPSHOT_PARTITION_ID) + } + // both snapshotPartitionId and snapshotStartBatchId are required at the same time, because + // each partition may have different checkpoint status + if (snapshotPartitionId.isDefined && snapshotStartBatchId.isEmpty) { + throw StateDataSourceErrors.requiredOptionUnspecified(SNAPSHOT_START_BATCH_ID) + } else if (snapshotPartitionId.isEmpty && snapshotStartBatchId.isDefined) { + throw StateDataSourceErrors.requiredOptionUnspecified(SNAPSHOT_PARTITION_ID) + } + + if (snapshotStartBatchId.isDefined && snapshotPartitionId.isDefined) { + fromSnapshotOptions = Some( + FromSnapshotOptions(snapshotStartBatchId.get, snapshotPartitionId.get)) + } } StateSourceOptions( - resolvedCpLocation, batchId, operatorId, storeName, - joinSide, snapshotStartBatchId, snapshotPartitionId) + resolvedCpLocation, batchId.get, operatorId, storeName, joinSide, + readChangeFeed, fromSnapshotOptions, readChangeFeedOptions) } private def resolvedCheckpointLocation( diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index 8461603e9652..6201cf1157ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -23,8 +23,10 @@ import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, Par import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil import org.apache.spark.sql.execution.streaming.state._ +import org.apache.spark.sql.execution.streaming.state.RecordType.{getRecordTypeAsString, RecordType} import org.apache.spark.sql.types.StructType -import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.{NextIterator, SerializableConfiguration} /** * An implementation of [[PartitionReaderFactory]] for State data source. This is used to support @@ -37,8 +39,14 @@ class StatePartitionReaderFactory( stateStoreMetadata: Array[StateMetadataTableEntry]) extends PartitionReaderFactory { override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { - new StatePartitionReader(storeConf, hadoopConf, - partition.asInstanceOf[StateStoreInputPartition], schema, stateStoreMetadata) + val stateStoreInputPartition = partition.asInstanceOf[StateStoreInputPartition] + if (stateStoreInputPartition.sourceOptions.readChangeFeed) { + new StateStoreChangeDataPartitionReader(storeConf, hadoopConf, + stateStoreInputPartition, schema, stateStoreMetadata) + } else { + new StatePartitionReader(storeConf, hadoopConf, + stateStoreInputPartition, schema, stateStoreMetadata) + } } } @@ -46,18 +54,17 @@ class StatePartitionReaderFactory( * An implementation of [[PartitionReader]] for State data source. This is used to support * general read from a state store instance, rather than specific to the operator. 
*/ -class StatePartitionReader( +abstract class StatePartitionReaderBase( storeConf: StateStoreConf, hadoopConf: SerializableConfiguration, partition: StateStoreInputPartition, schema: StructType, stateStoreMetadata: Array[StateMetadataTableEntry]) extends PartitionReader[InternalRow] with Logging { - private val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType] private val valueSchema = SchemaUtil.getSchemaAsDataType(schema, "value").asInstanceOf[StructType] - private lazy val provider: StateStoreProvider = { + protected lazy val provider: StateStoreProvider = { val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString, partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName) val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId) @@ -88,25 +95,7 @@ class StatePartitionReader( useMultipleValuesPerKey = false) } - private lazy val store: ReadStateStore = { - partition.sourceOptions.snapshotStartBatchId match { - case None => provider.getReadStore(partition.sourceOptions.batchId + 1) - - case Some(snapshotStartBatchId) => - if (!provider.isInstanceOf[SupportsFineGrainedReplay]) { - throw StateStoreErrors.stateStoreProviderDoesNotSupportFineGrainedReplay( - provider.getClass.toString) - } - provider.asInstanceOf[SupportsFineGrainedReplay] - .replayReadStateFromSnapshot( - snapshotStartBatchId + 1, - partition.sourceOptions.batchId + 1) - } - } - - private lazy val iter: Iterator[InternalRow] = { - store.iterator().map(pair => unifyStateRowPair((pair.key, pair.value))) - } + protected val iter: Iterator[InternalRow] private var current: InternalRow = _ @@ -124,9 +113,46 @@ class StatePartitionReader( override def close(): Unit = { current = null - store.abort() provider.close() } +} + +/** + * An implementation of [[StatePartitionReaderBase]] for the normal mode of State Data + * Source. It reads the the state at a particular batchId. + */ +class StatePartitionReader( + storeConf: StateStoreConf, + hadoopConf: SerializableConfiguration, + partition: StateStoreInputPartition, + schema: StructType, + stateStoreMetadata: Array[StateMetadataTableEntry]) + extends StatePartitionReaderBase(storeConf, hadoopConf, partition, schema, stateStoreMetadata) { + + private lazy val store: ReadStateStore = { + partition.sourceOptions.fromSnapshotOptions match { + case None => provider.getReadStore(partition.sourceOptions.batchId + 1) + + case Some(fromSnapshotOptions) => + if (!provider.isInstanceOf[SupportsFineGrainedReplay]) { + throw StateStoreErrors.stateStoreProviderDoesNotSupportFineGrainedReplay( + provider.getClass.toString) + } + provider.asInstanceOf[SupportsFineGrainedReplay] + .replayReadStateFromSnapshot( + fromSnapshotOptions.snapshotStartBatchId + 1, + partition.sourceOptions.batchId + 1) + } + } + + override lazy val iter: Iterator[InternalRow] = { + store.iterator().map(pair => unifyStateRowPair((pair.key, pair.value))) + } + + override def close(): Unit = { + store.abort() + super.close() + } private def unifyStateRowPair(pair: (UnsafeRow, UnsafeRow)): InternalRow = { val row = new GenericInternalRow(3) @@ -136,3 +162,48 @@ class StatePartitionReader( row } } + +/** + * An implementation of [[StatePartitionReaderBase]] for the readChangeFeed mode of State Data + * Source. It reads the change of state over batches of a particular partition. 
+ */ +class StateStoreChangeDataPartitionReader( + storeConf: StateStoreConf, + hadoopConf: SerializableConfiguration, + partition: StateStoreInputPartition, + schema: StructType, + stateStoreMetadata: Array[StateMetadataTableEntry]) + extends StatePartitionReaderBase(storeConf, hadoopConf, partition, schema, stateStoreMetadata) { + + private lazy val changeDataReader: + NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)] = { + if (!provider.isInstanceOf[SupportsFineGrainedReplay]) { + throw StateStoreErrors.stateStoreProviderDoesNotSupportFineGrainedReplay( + provider.getClass.toString) + } + provider.asInstanceOf[SupportsFineGrainedReplay] + .getStateStoreChangeDataReader( + partition.sourceOptions.readChangeFeedOptions.get.changeStartBatchId + 1, + partition.sourceOptions.readChangeFeedOptions.get.changeEndBatchId + 1) + } + + override lazy val iter: Iterator[InternalRow] = { + changeDataReader.iterator.map(unifyStateChangeDataRow) + } + + override def close(): Unit = { + changeDataReader.closeIfNeeded() + super.close() + } + + private def unifyStateChangeDataRow(row: (RecordType, UnsafeRow, UnsafeRow, Long)): + InternalRow = { + val result = new GenericInternalRow(5) + result.update(0, row._4) + result.update(1, UTF8String.fromString(getRecordTypeAsString(row._1))) + result.update(2, row._2) + result.update(3, row._3) + result.update(4, partition.partition) + result + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala index 821a36977fed..01f966ae948a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala @@ -86,17 +86,18 @@ class StateScan( assert((tail - head + 1) == partitionNums.length, s"No continuous partitions in state: ${partitionNums.mkString("Array(", ", ", ")")}") - sourceOptions.snapshotPartitionId match { + sourceOptions.fromSnapshotOptions match { case None => partitionNums.map { pn => new StateStoreInputPartition(pn, queryId, sourceOptions) }.toArray - case Some(snapshotPartitionId) => - if (partitionNums.contains(snapshotPartitionId)) { - Array(new StateStoreInputPartition(snapshotPartitionId, queryId, sourceOptions)) + case Some(fromSnapshotOptions) => + if (partitionNums.contains(fromSnapshotOptions.snapshotPartitionId)) { + Array(new StateStoreInputPartition( + fromSnapshotOptions.snapshotPartitionId, queryId, sourceOptions)) } else { throw StateStoreErrors.stateStoreSnapshotPartitionNotFound( - snapshotPartitionId, sourceOptions.operatorId, + fromSnapshotOptions.snapshotPartitionId, sourceOptions.operatorId, sourceOptions.stateCheckpointLocation.toString) } } @@ -128,16 +129,27 @@ class StateScan( override def toBatch: Batch = this override def description(): String = { - val desc = s"StateScan " + + var desc = s"StateScan " + s"[stateCkptLocation=${sourceOptions.stateCheckpointLocation}]" + s"[batchId=${sourceOptions.batchId}][operatorId=${sourceOptions.operatorId}]" + s"[storeName=${sourceOptions.storeName}]" if (sourceOptions.joinSide != JoinSideValues.none) { - desc + s"[joinSide=${sourceOptions.joinSide}]" - } else { - desc + desc += s"[joinSide=${sourceOptions.joinSide}]" + } + sourceOptions.fromSnapshotOptions match { + case Some(fromSnapshotOptions) => + desc += 
s"[snapshotStartBatchId=${fromSnapshotOptions.snapshotStartBatchId}]" + desc += s"[snapshotPartitionId=${fromSnapshotOptions.snapshotPartitionId}]" + case _ => + } + sourceOptions.readChangeFeedOptions match { + case Some(fromSnapshotOptions) => + desc += s"[changeStartBatchId=${fromSnapshotOptions.changeStartBatchId}" + desc += s"[changeEndBatchId=${fromSnapshotOptions.changeEndBatchId}" + case _ => } + desc } private def stateCheckpointPartitionsLocation: Path = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala index 2d2c9631e537..2fc85cd8aa96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.Jo import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil import org.apache.spark.sql.execution.streaming.state.StateStoreConf -import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.ArrayImplicits._ @@ -59,11 +59,17 @@ class StateTable( if (sourceOptions.joinSide != JoinSideValues.none) { desc += s"[joinSide=${sourceOptions.joinSide}]" } - if (sourceOptions.snapshotStartBatchId.isDefined) { - desc += s"[snapshotStartBatchId=${sourceOptions.snapshotStartBatchId}]" + sourceOptions.fromSnapshotOptions match { + case Some(fromSnapshotOptions) => + desc += s"[snapshotStartBatchId=${fromSnapshotOptions.snapshotStartBatchId}]" + desc += s"[snapshotPartitionId=${fromSnapshotOptions.snapshotPartitionId}]" + case _ => } - if (sourceOptions.snapshotPartitionId.isDefined) { - desc += s"[snapshotPartitionId=${sourceOptions.snapshotPartitionId}]" + sourceOptions.readChangeFeedOptions match { + case Some(fromSnapshotOptions) => + desc += s"[changeStartBatchId=${fromSnapshotOptions.changeStartBatchId}" + desc += s"[changeEndBatchId=${fromSnapshotOptions.changeEndBatchId}" + case _ => } desc } @@ -76,16 +82,26 @@ class StateTable( override def properties(): util.Map[String, String] = Map.empty[String, String].asJava private def isValidSchema(schema: StructType): Boolean = { - if (schema.fieldNames.toImmutableArraySeq != Seq("key", "value", "partition_id")) { - false - } else if (!SchemaUtil.getSchemaAsDataType(schema, "key").isInstanceOf[StructType]) { - false - } else if (!SchemaUtil.getSchemaAsDataType(schema, "value").isInstanceOf[StructType]) { - false - } else if (!SchemaUtil.getSchemaAsDataType(schema, "partition_id").isInstanceOf[IntegerType]) { + val expectedFieldNames = + if (sourceOptions.readChangeFeed) { + Seq("batch_id", "change_type", "key", "value", "partition_id") + } else { + Seq("key", "value", "partition_id") + } + val expectedTypes = Map( + "batch_id" -> classOf[LongType], + "change_type" -> classOf[StringType], + "key" -> classOf[StructType], + "value" -> classOf[StructType], + "partition_id" -> classOf[IntegerType]) + + if (schema.fieldNames.toImmutableArraySeq != expectedFieldNames) { false } else { - true + schema.fieldNames.forall { fieldName => + expectedTypes(fieldName).isAssignableFrom( + 
SchemaUtil.getSchemaAsDataType(schema, fieldName).getClass) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala index 91f42db46dfb..673ec3414c23 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala @@ -117,7 +117,8 @@ class StreamStreamJoinStatePartitionReader( formatVersion, skippedNullValueCount = None, useStateStoreCoordinator = false, - snapshotStartVersion = partition.sourceOptions.snapshotStartBatchId.map(_ + 1) + snapshotStartVersion = + partition.sourceOptions.fromSnapshotOptions.map(_.snapshotStartBatchId + 1) ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index c4a41ceb4caf..2ec36166f9f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -978,4 +978,47 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with result } + + override def getStateStoreChangeDataReader(startVersion: Long, endVersion: Long): + StateStoreChangeDataReader = { + new HDFSBackedStateStoreChangeDataReader(fm, baseDir, startVersion, endVersion, + CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec), + keySchema, valueSchema) + } +} + +/** [[StateStoreChangeDataReader]] implementation for [[HDFSBackedStateStoreProvider]] */ +class HDFSBackedStateStoreChangeDataReader( + fm: CheckpointFileManager, + stateLocation: Path, + startVersion: Long, + endVersion: Long, + compressionCodec: CompressionCodec, + keySchema: StructType, + valueSchema: StructType) + extends StateStoreChangeDataReader( + fm, stateLocation, startVersion, endVersion, compressionCodec) { + + override protected var changelogSuffix: String = "delta" + + override def getNext(): (RecordType.Value, UnsafeRow, UnsafeRow, Long) = { + val reader = currentChangelogReader() + if (reader == null) { + return null + } + val (recordType, keyArray, valueArray) = reader.next() + val keyRow = new UnsafeRow(keySchema.fields.length) + keyRow.pointTo(keyArray, keyArray.length) + if (valueArray == null) { + (recordType, keyRow, null, currentChangelogVersion - 1) + } else { + val valueRow = new UnsafeRow(valueSchema.fields.length) + // If valueSize in existing file is not multiple of 8, floor it to multiple of 8. 
+ // This is a workaround for the following: + // Prior to Spark 2.3 mistakenly append 4 bytes to the value row in + // `RowBasedKeyValueBatch`, which gets persisted into the checkpoint data + valueRow.pointTo(valueArray, (valueArray.length / 8) * 8) + (recordType, keyRow, valueRow, currentChangelogVersion - 1) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 15ef8832ef35..6215d1aaf4b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -324,7 +324,7 @@ class RocksDB( } } } finally { - if (changelogReader != null) changelogReader.close() + if (changelogReader != null) changelogReader.closeIfNeeded() } } loadedVersion = endVersion diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 497d48946448..a5a8d27116ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -18,18 +18,22 @@ package org.apache.spark.sql.execution.streaming.state import java.io._ +import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicInteger import scala.jdk.CollectionConverters._ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys._ +import org.apache.spark.io.CompressionCodec import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.streaming.CheckpointFileManager import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -461,6 +465,19 @@ private[sql] class RocksDBStateStoreProvider } } + override def getStateStoreChangeDataReader(startVersion: Long, endVersion: Long): + StateStoreChangeDataReader = { + val statePath = stateStoreId.storeCheckpointLocation() + val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) + new RocksDBStateStoreChangeDataReader( + CheckpointFileManager.create(statePath, hadoopConf), + statePath, + startVersion, + endVersion, + CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec), + keyValueEncoderMap) + } + /** * Class for column family related utility functions. * Verification functions for column family names, column family operation validations etc. 
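// A minimal consumption sketch for the change data reader wired up above, assuming `provider`
// is an already-initialized provider that implements SupportsFineGrainedReplay and that the
// version bounds 1 and 4 are placeholders. The reader is a NextIterator of
// (recordType, key, value, batchId) tuples; delete records carry a null value, matching the
// getNext() implementations in this change.
val changeReader = provider.asInstanceOf[SupportsFineGrainedReplay]
  .getStateStoreChangeDataReader(startVersion = 1, endVersion = 4)
changeReader.foreach { case (recordType, keyRow, valueRow, batchId) =>
  // RecordType.PUT_RECORD rows carry a non-null valueRow; RecordType.DELETE_RECORD rows do not.
  assert(recordType == RecordType.PUT_RECORD || valueRow == null)
}
changeReader.closeIfNeeded()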
@@ -670,3 +687,36 @@ object RocksDBStateStoreProvider { CUSTOM_METRIC_PINNED_BLOCKS_MEM_USAGE, CUSTOM_METRIC_NUM_EXTERNAL_COL_FAMILIES, CUSTOM_METRIC_NUM_INTERNAL_COL_FAMILIES) } + +/** [[StateStoreChangeDataReader]] implementation for [[RocksDBStateStoreProvider]] */ +class RocksDBStateStoreChangeDataReader( + fm: CheckpointFileManager, + stateLocation: Path, + startVersion: Long, + endVersion: Long, + compressionCodec: CompressionCodec, + keyValueEncoderMap: + ConcurrentHashMap[String, (RocksDBKeyStateEncoder, RocksDBValueStateEncoder)]) + extends StateStoreChangeDataReader( + fm, stateLocation, startVersion, endVersion, compressionCodec) { + + override protected var changelogSuffix: String = "changelog" + + override def getNext(): (RecordType.Value, UnsafeRow, UnsafeRow, Long) = { + val reader = currentChangelogReader() + if (reader == null) { + return null + } + val (recordType, keyArray, valueArray) = reader.next() + // Todo: does not support multiple virtual column families + val (rocksDBKeyStateEncoder, rocksDBValueStateEncoder) = + keyValueEncoderMap.get(StateStore.DEFAULT_COL_FAMILY_NAME) + val keyRow = rocksDBKeyStateEncoder.decodeKey(keyArray) + if (valueArray == null) { + (recordType, keyRow, null, currentChangelogVersion - 1) + } else { + val valueRow = rocksDBValueStateEncoder.decodeValue(valueArray) + (recordType, keyRow, valueRow, currentChangelogVersion - 1) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 76fd36bd726a..0dc5414b7398 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo import org.apache.spark.sql.types.StructType -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.util.{NextIterator, ThreadUtils, Utils} /** * Base trait for a versioned key-value store which provides read operations. Each instance of a @@ -439,9 +439,9 @@ object StateStoreProvider { } /** - * This is an optional trait to be implemented by [[StateStoreProvider]]s that can read fine - * grained state data which is replayed from a specific snapshot version. It is used by the - * snapshotStartBatchId option in state data source. + * This is an optional trait to be implemented by [[StateStoreProvider]]s that can read the change + * of state store over batches. This is used by State Data Source with additional options like + * snapshotStartBatchId or readChangeFeed. */ trait SupportsFineGrainedReplay { @@ -469,6 +469,22 @@ trait SupportsFineGrainedReplay { def replayReadStateFromSnapshot(snapshotVersion: Long, endVersion: Long): ReadStateStore = { new WrappedReadStateStore(replayStateFromSnapshot(snapshotVersion, endVersion)) } + + /** + * Return an iterator that reads all the entries of changelogs from startVersion to + * endVersion. 
+ * Each record is represented by a tuple of (recordType: [[RecordType.Value]], key: [[UnsafeRow]], + * value: [[UnsafeRow]], batchId: [[Long]]) + * A put record is returned as a tuple(recordType, key, value, batchId) + * A delete record is return as a tuple(recordType, key, null, batchId) + * + * @param startVersion starting changelog version + * @param endVersion ending changelog version + * @return iterator that gives tuple(recordType: [[RecordType.Value]], nested key: [[UnsafeRow]], + * nested value: [[UnsafeRow]], batchId: [[Long]]) + */ + def getStateStoreChangeDataReader(startVersion: Long, endVersion: Long): + NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)] } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala index 04388589bb0b..d189daa6e841 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.fs.{FSError, Path} import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys._ import org.apache.spark.io.CompressionCodec +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.streaming.CheckpointFileManager import org.apache.spark.sql.execution.streaming.CheckpointFileManager.CancellableFSDataOutputStream @@ -55,6 +56,15 @@ object RecordType extends Enumeration { } } + def getRecordTypeAsString(recordType: RecordType): String = { + recordType match { + case PUT_RECORD => "update" + case DELETE_RECORD => "delete" + case _ => throw StateStoreErrors.unsupportedOperationException( + "getRecordTypeAsString", recordType.toString) + } + } + // Generate record type from byte representation def getRecordTypeFromByte(byte: Byte): RecordType = { byte match { @@ -260,17 +270,17 @@ abstract class StateStoreChangelogReader( } protected val input: DataInputStream = decompressStream(sourceStream) - def close(): Unit = { if (input != null) input.close() } + override protected def close(): Unit = { if (input != null) input.close() } override def getNext(): (RecordType.Value, Array[Byte], Array[Byte]) } /** * Read an iterator of change record from the changelog file. - * A record is represented by ByteArrayPair(recordType: RecordType.Value, - * key: Array[Byte], value: Array[Byte], colFamilyName: String) - * A put record is returned as a ByteArrayPair(recordType, key, value, colFamilyName) - * A delete record is return as a ByteArrayPair(recordType, key, null, colFamilyName) + * A record is represented by tuple(recordType: RecordType.Value, + * key: Array[Byte], value: Array[Byte]) + * A put record is returned as a tuple(recordType, key, value) + * A delete record is return as a tuple(recordType, key, null) */ class StateStoreChangelogReaderV1( fm: CheckpointFileManager, @@ -307,10 +317,10 @@ class StateStoreChangelogReaderV1( /** * Read an iterator of change record from the changelog file. 
- * A record is represented by ByteArrayPair(recordType: RecordType.Value, - * key: Array[Byte], value: Array[Byte], colFamilyName: String) - * A put record is returned as a ByteArrayPair(recordType, key, value, colFamilyName) - * A delete record is return as a ByteArrayPair(recordType, key, null, colFamilyName) + * A record is represented by tuple(recordType: RecordType.Value, + * key: Array[Byte], value: Array[Byte]) + * A put record is returned as a tuple(recordType, key, value) + * A delete record is return as a tuple(recordType, key, null) */ class StateStoreChangelogReaderV2( fm: CheckpointFileManager, @@ -355,3 +365,84 @@ class StateStoreChangelogReaderV2( } } } + +/** + * Base class representing a iterator that iterates over a range of changelog files in a state + * store. In each iteration, it will return a tuple of (changeType: [[RecordType]], + * nested key: [[UnsafeRow]], nested value: [[UnsafeRow]], batchId: [[Long]]) + * + * @param fm checkpoint file manager used to manage streaming query checkpoint + * @param stateLocation location of the state store + * @param startVersion start version of the changelog file to read + * @param endVersion end version of the changelog file to read + * @param compressionCodec de-compression method using for reading changelog file + */ +abstract class StateStoreChangeDataReader( + fm: CheckpointFileManager, + stateLocation: Path, + startVersion: Long, + endVersion: Long, + compressionCodec: CompressionCodec) + extends NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)] with Logging { + + assert(startVersion >= 1) + assert(endVersion >= startVersion) + + /** + * Iterator that iterates over the changelog files in the state store. + */ + private class ChangeLogFileIterator extends Iterator[Path] { + + private var currentVersion = StateStoreChangeDataReader.this.startVersion - 1 + + /** returns the version of the changelog returned by the latest [[next]] function call */ + def getVersion: Long = currentVersion + + override def hasNext: Boolean = currentVersion < StateStoreChangeDataReader.this.endVersion + + override def next(): Path = { + currentVersion += 1 + getChangelogPath(currentVersion) + } + + private def getChangelogPath(version: Long): Path = + new Path( + StateStoreChangeDataReader.this.stateLocation, + s"$version.${StateStoreChangeDataReader.this.changelogSuffix}") + } + + /** file format of the changelog files */ + protected var changelogSuffix: String + private lazy val fileIterator = new ChangeLogFileIterator + private var changelogReader: StateStoreChangelogReader = null + + /** + * Get a changelog reader that has at least one record left to read. If there is no readers left, + * return null. 
+ */ + protected def currentChangelogReader(): StateStoreChangelogReader = { + while (changelogReader == null || !changelogReader.hasNext) { + if (changelogReader != null) { + changelogReader.closeIfNeeded() + changelogReader = null + } + if (!fileIterator.hasNext) { + finished = true + return null + } + // Todo: Does not support StateStoreChangelogReaderV2 + changelogReader = + new StateStoreChangelogReaderV1(fm, fileIterator.next(), compressionCodec) + } + changelogReader + } + + /** get the version of the current changelog reader */ + protected def currentChangelogVersion: Long = fileIterator.getVersion + + override def close(): Unit = { + if (changelogReader != null) { + changelogReader.closeIfNeeded() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala new file mode 100644 index 000000000000..2858d356d4c9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala @@ -0,0 +1,321 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2.state + +import org.apache.hadoop.conf.Configuration +import org.scalatest.Assertions + +import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.state._ +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.StructType + +class HDFSBackedStateDataSourceChangeDataReaderSuite extends StateDataSourceChangeDataReaderSuite { + override protected def newStateStoreProvider(): HDFSBackedStateStoreProvider = + new HDFSBackedStateStoreProvider +} + +class RocksDBWithChangelogCheckpointStateDataSourceChangeDataReaderSuite extends + StateDataSourceChangeDataReaderSuite { + override protected def newStateStoreProvider(): RocksDBStateStoreProvider = + new RocksDBStateStoreProvider + + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set("spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled", + "true") + } +} + +abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestBase + with Assertions { + + import testImplicits._ + import StateStoreTestsHelper._ + + protected val keySchema: StructType = StateStoreTestsHelper.keySchema + protected val valueSchema: StructType = StateStoreTestsHelper.valueSchema + + protected def newStateStoreProvider(): StateStoreProvider + + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set(SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED, false) + spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key, newStateStoreProvider().getClass.getName) + } + + /** + * Calls the overridable [[newStateStoreProvider]] to create the state store provider instance. + * Initialize it with the configuration set by child classes. 
+ * + * @param checkpointDir path to store state information + * @return instance of class extending [[StateStoreProvider]] + */ + private def getNewStateStoreProvider(checkpointDir: String): StateStoreProvider = { + val provider = newStateStoreProvider() + provider.init( + StateStoreId(checkpointDir, 0, 0), + keySchema, + valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema), + useColumnFamilies = false, + StateStoreConf(spark.sessionState.conf), + new Configuration) + provider + } + + test("ERROR: specify changeStartBatchId in normal mode") { + withTempDir { tempDir => + val exc = intercept[StateDataSourceInvalidOptionValue] { + spark.read.format("statestore") + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 2) + .load(tempDir.getAbsolutePath) + } + assert(exc.getErrorClass === "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE") + } + } + + test("ERROR: changeStartBatchId is set to negative") { + withTempDir { tempDir => + val exc = intercept[StateDataSourceInvalidOptionValueIsNegative] { + spark.read.format("statestore") + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, -1) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 0) + .load(tempDir.getAbsolutePath) + } + assert(exc.getErrorClass === "STDS_INVALID_OPTION_VALUE.IS_NEGATIVE") + } + } + + test("ERROR: changeEndBatchId is set to less than changeStartBatchId") { + withTempDir { tempDir => + val exc = intercept[StateDataSourceInvalidOptionValue] { + spark.read.format("statestore") + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 1) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 0) + .load(tempDir.getAbsolutePath) + } + assert(exc.getErrorClass === "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE") + } + } + + test("ERROR: joinSide option is used together with readChangeFeed") { + withTempDir { tempDir => + val exc = intercept[StateDataSourceConflictOptions] { + spark.read.format("statestore") + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.JOIN_SIDE, "left") + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 0) + .load(tempDir.getAbsolutePath) + } + assert(exc.getErrorClass === "STDS_CONFLICT_OPTIONS") + } + } + + test("getChangeDataReader of state store provider") { + def withNewStateStore(provider: StateStoreProvider, version: Int)(f: StateStore => Unit): + Unit = { + val stateStore = provider.getStore(version) + f(stateStore) + stateStore.commit() + } + + withTempDir { tempDir => + val provider = getNewStateStoreProvider(tempDir.getAbsolutePath) + withNewStateStore(provider, 0) { stateStore => + put(stateStore, "a", 1, 1) } + withNewStateStore(provider, 1) { stateStore => + put(stateStore, "b", 2, 2) } + withNewStateStore(provider, 2) { stateStore => + stateStore.remove(dataToKeyRow("a", 1)) } + withNewStateStore(provider, 3) { stateStore => + stateStore.remove(dataToKeyRow("b", 2)) } + + val reader = + provider.asInstanceOf[SupportsFineGrainedReplay].getStateStoreChangeDataReader(1, 4) + + assert(reader.next() === (RecordType.PUT_RECORD, dataToKeyRow("a", 1), dataToValueRow(1), 0L)) + assert(reader.next() === (RecordType.PUT_RECORD, dataToKeyRow("b", 2), dataToValueRow(2), 1L)) + assert(reader.next() === + (RecordType.DELETE_RECORD, dataToKeyRow("a", 1), null, 2L)) + assert(reader.next() === + (RecordType.DELETE_RECORD, dataToKeyRow("b", 2), null, 3L)) + } + } + + 
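// A DataFrame-level counterpart to the reader-level assertions above, as a sketch: each
// (recordType, key, value, batchId) tuple surfaces through the state data source as a row of
// (batch_id, change_type, key, value, partition_id), with PUT_RECORD mapped to change_type
// "update" and DELETE_RECORD to "delete". The helper name is hypothetical, and `checkpointDir`
// is assumed to be the checkpoint location of a streaming query rather than the raw provider
// directory used in the previous test.
private def readChangeFeedForCheckpoint(checkpointDir: String) = {
  spark.read.format("statestore")
    .option(StateSourceOptions.READ_CHANGE_FEED, value = true)
    .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0)
    .option(StateSourceOptions.CHANGE_END_BATCH_ID, 2)
    .load(checkpointDir)
    .select("batch_id", "change_type", "key", "value", "partition_id")
}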
test("read global streaming limit state change feed") { + withTempDir { tempDir => + val inputData = MemoryStream[Int] + val df = inputData.toDF().limit(10) + testStream(df)( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, 1, 2, 3, 4), + ProcessAllAvailable(), + AddData(inputData, 5, 6, 7, 8), + ProcessAllAvailable(), + AddData(inputData, 9, 10, 11, 12), + ProcessAllAvailable() + ) + + val stateDf = spark.read.format("statestore") + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 2) + .load(tempDir.getAbsolutePath) + + val expectedDf = Seq( + Row(0L, "update", Row(null), Row(4), 0), + Row(1L, "update", Row(null), Row(8), 0), + Row(2L, "update", Row(null), Row(10), 0) + ) + + checkAnswer(stateDf, expectedDf) + } + } + + test("read streaming aggregate state change feed") { + withTempDir { tempDir => + val inputData = MemoryStream[Int] + val df = inputData.toDF().groupBy("value").count() + testStream(df, OutputMode.Update)( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, 1, 2, 3, 4), + ProcessAllAvailable(), + AddData(inputData, 2, 3, 4, 5), + ProcessAllAvailable(), + AddData(inputData, 3, 4, 5, 6), + ProcessAllAvailable() + ) + + val stateDf = spark.read.format("statestore") + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 2) + .load(tempDir.getAbsolutePath) + + val expectedDf = Seq( + Row(0L, "update", Row(3), Row(1), 1), + Row(1L, "update", Row(3), Row(2), 1), + Row(1L, "update", Row(5), Row(1), 1), + Row(2L, "update", Row(3), Row(3), 1), + Row(2L, "update", Row(5), Row(2), 1), + Row(0L, "update", Row(4), Row(1), 2), + Row(1L, "update", Row(4), Row(2), 2), + Row(2L, "update", Row(4), Row(3), 2), + Row(0L, "update", Row(1), Row(1), 3), + Row(0L, "update", Row(2), Row(1), 4), + Row(1L, "update", Row(2), Row(2), 4), + Row(2L, "update", Row(6), Row(1), 4) + ) + + checkAnswer(stateDf, expectedDf) + } + } + + test("read streaming deduplication state change feed") { + withTempDir { tempDir => + val inputData = MemoryStream[Int] + val df = inputData.toDF().dropDuplicates("value") + testStream(df, OutputMode.Update)( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, 1, 2, 3, 4), + ProcessAllAvailable(), + AddData(inputData, 2, 3, 4, 5), + ProcessAllAvailable(), + AddData(inputData, 3, 4, 5, 6), + ProcessAllAvailable() + ) + + val stateDf = spark.read.format("statestore") + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 2) + .load(tempDir.getAbsolutePath) + + val expectedDf = Seq( + Row(0L, "update", Row(1), Row(null), 3), + Row(0L, "update", Row(2), Row(null), 4), + Row(0L, "update", Row(3), Row(null), 1), + Row(0L, "update", Row(4), Row(null), 2), + Row(1L, "update", Row(5), Row(null), 1), + Row(2L, "update", Row(6), Row(null), 4) + ) + + checkAnswer(stateDf, expectedDf) + } + } + + test("read stream-stream join state change feed") { + withTempDir { tempDir => + val inputData = MemoryStream[(Int, Long)] + val leftDf = + inputData.toDF().select(col("_1").as("leftKey"), col("_2").as("leftValue")) + val rightDf = + inputData.toDF().select((col("_1") * 2).as("rightKey"), col("_2").as("rightValue")) + val df = leftDf.join(rightDf).where("leftKey == rightKey") + 
+ testStream(df)( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, (1, 1L), (2, 2L)), + ProcessAllAvailable(), + AddData(inputData, (3, 3L), (4, 4L)), + ProcessAllAvailable() + ) + + val keyWithIndexToValueDf = spark.read.format("statestore") + .option(StateSourceOptions.STORE_NAME, "left-keyWithIndexToValue") + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 1) + .load(tempDir.getAbsolutePath) + + val keyWithIndexToValueExpectedDf = Seq( + Row(1L, "update", Row(3, 0L), Row(3, 3L, false), 1), + Row(1L, "update", Row(4, 0L), Row(4, 4L, true), 2), + Row(0L, "update", Row(1, 0L), Row(1, 1L, false), 3), + Row(0L, "update", Row(2, 0L), Row(2, 2L, false), 4), + Row(0L, "update", Row(2, 0L), Row(2, 2L, true), 4) + ) + + checkAnswer(keyWithIndexToValueDf, keyWithIndexToValueExpectedDf) + + val keyToNumValuesDf = spark.read.format("statestore") + .option(StateSourceOptions.STORE_NAME, "left-keyToNumValues") + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 1) + .load(tempDir.getAbsolutePath) + + val keyToNumValuesDfExpectedDf = Seq( + Row(1L, "update", Row(3), Row(1L), 1), + Row(1L, "update", Row(4), Row(1L), 2), + Row(0L, "update", Row(1), Row(1L), 3), + Row(0L, "update", Row(2), Row(1L), 4) + ) + + checkAnswer(keyToNumValuesDf, keyToNumValuesDfExpectedDf) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala index f5392cc823f7..705d9f125964 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala @@ -383,7 +383,7 @@ trait StateDataSourceTestBase extends StreamTest with StateStoreMetricsTest { } } - private def getStreamStreamJoinQuery(inputStream: MemoryStream[(Int, Long)]): DataFrame = { + protected def getStreamStreamJoinQuery(inputStream: MemoryStream[(Int, Long)]): DataFrame = { val df = inputStream.toDS() .select(col("_1").as("value"), timestamp_seconds($"_2").as("timestamp"))
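End-to-end, the options introduced by this change are exercised as follows; this is an illustrative sketch (the checkpoint path is a placeholder) and assumes the streaming query ran with a state store provider implementing SupportsFineGrainedReplay, such as the HDFS-backed provider or the RocksDB provider with changelog checkpointing enabled:

val changes = spark.read.format("statestore")
  .option("readChangeFeed", true)    // cannot be combined with batchId, joinSide or the snapshot* options
  .option("changeStartBatchId", 0)   // required, must be non-negative
  .option("changeEndBatchId", 2)     // optional, defaults to the last committed batch
  .load("/path/to/checkpoint")       // placeholder streaming query checkpoint location

The returned DataFrame has the schema (batch_id LONG, change_type STRING, key STRUCT, value STRUCT, partition_id INT), where change_type is "update" for put records and "delete" for delete records, and value is null on "delete" rows.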