diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala index 28de21aaf9389..50abb6d18b987 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala @@ -31,8 +31,8 @@ import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder} import org.apache.spark.sql.execution.datasources.v2.state.StateDataSourceErrors import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.PATH -import org.apache.spark.sql.execution.streaming.CheckpointFileManager -import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataReader, OperatorStateMetadataV1} +import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, OperatorStateMetadataLog} +import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataReader, OperatorStateMetadataV1, OperatorStateMetadataV2} import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -46,7 +46,8 @@ case class StateMetadataTableEntry( numPartitions: Int, minBatchId: Long, maxBatchId: Long, - numColsPrefixKey: Int) { + numColsPrefixKey: Int, + operatorPropertiesJson: String) { def toRow(): InternalRow = { new GenericInternalRow( Array[Any](operatorId, @@ -55,7 +56,8 @@ case class StateMetadataTableEntry( numPartitions, minBatchId, maxBatchId, - numColsPrefixKey)) + numColsPrefixKey, + UTF8String.fromString(operatorPropertiesJson))) } } @@ -110,7 +112,14 @@ class StateMetadataTable extends Table with SupportsRead with SupportsMetadataCo override def comment: String = "Number of columns in prefix key of the state store instance" } - override val metadataColumns: Array[MetadataColumn] = Array(NumColsPrefixKeyColumn) + private object OperatorPropertiesColumn extends MetadataColumn { + override def name: String = "_operatorProperties" + override def dataType: DataType = StringType + override def comment: String = "Json string storing operator properties" + } + + override val metadataColumns: Array[MetadataColumn] = + Array(NumColsPrefixKeyColumn, OperatorPropertiesColumn) } case class StateMetadataInputPartition(checkpointLocation: String) extends InputPartition @@ -188,28 +197,55 @@ class StateMetadataPartitionReader( } else Array.empty } - private[sql] def allOperatorStateMetadata: Array[OperatorStateMetadata] = { + private[sql] def allOperatorStateMetadata: Array[(OperatorStateMetadata, Long)] = { val stateDir = new Path(checkpointLocation, "state") val opIds = fileManager .list(stateDir, pathNameCanBeParsedAsLongFilter).map(f => pathToLong(f.getPath)).sorted - opIds.map { opId => - new OperatorStateMetadataReader(new Path(stateDir, opId.toString), hadoopConf).read() + opIds.flatMap { opId => + val operatorIdPath = new Path(stateDir, opId.toString) + // check if OperatorStateMetadataV2 path exists, if it does, read it + // otherwise, fall back to OperatorStateMetadataV1 + val operatorStateMetadataV2Path = 
OperatorStateMetadataV2.metadataFilePath(operatorIdPath) + if (fileManager.exists(operatorStateMetadataV2Path)) { + val operatorStateMetadataLog = new OperatorStateMetadataLog( + hadoopConf, operatorStateMetadataV2Path.toString) + operatorStateMetadataLog.listBatchesOnDisk.flatMap { batchId => + operatorStateMetadataLog.get(batchId).map((_, batchId)) + } + } else { + Array((new OperatorStateMetadataReader(operatorIdPath, hadoopConf).read(), -1L)) + } } } private[state] lazy val stateMetadata: Iterator[StateMetadataTableEntry] = { - allOperatorStateMetadata.flatMap { operatorStateMetadata => - require(operatorStateMetadata.version == 1) - val operatorStateMetadataV1 = operatorStateMetadata.asInstanceOf[OperatorStateMetadataV1] - operatorStateMetadataV1.stateStoreInfo.map { stateStoreMetadata => - StateMetadataTableEntry(operatorStateMetadataV1.operatorInfo.operatorId, - operatorStateMetadataV1.operatorInfo.operatorName, - stateStoreMetadata.storeName, - stateStoreMetadata.numPartitions, - if (batchIds.nonEmpty) batchIds.head else -1, - if (batchIds.nonEmpty) batchIds.last else -1, - stateStoreMetadata.numColsPrefixKey - ) + allOperatorStateMetadata.flatMap { case (operatorStateMetadata, batchId) => + require(operatorStateMetadata.version == 1 || operatorStateMetadata.version == 2) + operatorStateMetadata match { + case v1: OperatorStateMetadataV1 => + v1.stateStoreInfo.map { stateStoreMetadata => + StateMetadataTableEntry(v1.operatorInfo.operatorId, + v1.operatorInfo.operatorName, + stateStoreMetadata.storeName, + stateStoreMetadata.numPartitions, + if (batchIds.nonEmpty) batchIds.head else -1, + if (batchIds.nonEmpty) batchIds.last else -1, + stateStoreMetadata.numColsPrefixKey, + "" + ) + } + case v2: OperatorStateMetadataV2 => + v2.stateStoreInfo.map { stateStoreMetadata => + StateMetadataTableEntry(v2.operatorInfo.operatorId, + v2.operatorInfo.operatorName, + stateStoreMetadata.storeName, + stateStoreMetadata.numPartitions, + batchId, + batchId, + stateStoreMetadata.numColsPrefixKey, + v2.operatorPropertiesJson + ) + } } } }.iterator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ColumnFamilySchemaUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ColumnFamilySchemaUtils.scala index feced6810d3a3..9fdefe32f1e41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ColumnFamilySchemaUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ColumnFamilySchemaUtils.scala @@ -17,23 +17,30 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.Encoder -import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema._ import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchema, ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec} trait ColumnFamilySchemaUtils { def getValueStateSchema[T]( stateName: String, + keyEncoder: ExpressionEncoder[Any], + valEncoder: Encoder[T], hasTtl: Boolean): ColumnFamilySchema def getListStateSchema[T]( stateName: String, + keyEncoder: ExpressionEncoder[Any], + valEncoder: Encoder[T], hasTtl: Boolean): ColumnFamilySchema def getMapStateSchema[K, V]( stateName: String, + keyEncoder: ExpressionEncoder[Any], userKeyEnc: Encoder[K], + 
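For context on how the new _operatorProperties column added to StateMetadataTable above surfaces to users of the state-metadata source, here is a small usage sketch; the SparkSession setup and checkpoint path are assumed, and the example JSON value is the one asserted in the TransformWithStateSuite test later in this PR.

    // Illustrative only: query the state metadata of a checkpoint written by a
    // transformWithState query and read the operator properties recorded per batch.
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[2]").getOrCreate()   // assumed session
    val df = spark.read.format("state-metadata").load("/tmp/checkpoint")  // hypothetical path
    // _operatorProperties is a metadata column, so it has to be selected explicitly:
    df.select(df.metadataColumn("_operatorProperties")).show(truncate = false)
    // e.g. {"timeMode":"NoTime","outputMode":"Update"} for a transformWithState operator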
valEncoder: Encoder[V], hasTtl: Boolean): ColumnFamilySchema } @@ -41,44 +48,39 @@ object ColumnFamilySchemaUtilsV1 extends ColumnFamilySchemaUtils { def getValueStateSchema[T]( stateName: String, + keyEncoder: ExpressionEncoder[Any], + valEncoder: Encoder[T], hasTtl: Boolean): ColumnFamilySchemaV1 = { - new ColumnFamilySchemaV1( + ColumnFamilySchemaV1( stateName, - KEY_ROW_SCHEMA, - if (hasTtl) { - VALUE_ROW_SCHEMA_WITH_TTL - } else { - VALUE_ROW_SCHEMA - }, + getKeySchema(keyEncoder.schema), + getValueSchemaWithTTL(valEncoder.schema, hasTtl), NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)) } def getListStateSchema[T]( stateName: String, + keyEncoder: ExpressionEncoder[Any], + valEncoder: Encoder[T], hasTtl: Boolean): ColumnFamilySchemaV1 = { - new ColumnFamilySchemaV1( + ColumnFamilySchemaV1( stateName, - KEY_ROW_SCHEMA, - if (hasTtl) { - VALUE_ROW_SCHEMA_WITH_TTL - } else { - VALUE_ROW_SCHEMA - }, + getKeySchema(keyEncoder.schema), + getValueSchemaWithTTL(valEncoder.schema, hasTtl), NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)) } def getMapStateSchema[K, V]( stateName: String, + keyEncoder: ExpressionEncoder[Any], userKeyEnc: Encoder[K], + valEncoder: Encoder[V], hasTtl: Boolean): ColumnFamilySchemaV1 = { - new ColumnFamilySchemaV1( + val compositeKeySchema = getCompositeKeySchema(keyEncoder.schema, userKeyEnc.schema) + ColumnFamilySchemaV1( stateName, - COMPOSITE_KEY_ROW_SCHEMA, - if (hasTtl) { - VALUE_ROW_SCHEMA_WITH_TTL - } else { - VALUE_ROW_SCHEMA - }, + compositeKeySchema, + getValueSchemaWithTTL(valEncoder.schema, hasTtl), PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), Some(userKeyEnc.schema)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 251cc16acdf43..97589b83ec3cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -25,6 +25,7 @@ import scala.jdk.CollectionConverters._ import scala.reflect.ClassTag import org.apache.commons.io.IOUtils +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ import org.json4s.{Formats, NoTypeHints} import org.json4s.jackson.Serialization @@ -47,10 +48,25 @@ import org.apache.spark.util.ArrayImplicits._ * * Note: [[HDFSMetadataLog]] doesn't support S3-like file systems as they don't guarantee listing * files in a directory always shows the latest files. + * @param hadoopConf Hadoop configuration that is used to read / write metadata files. + * @param path Path to the directory that will be used for writing metadata. + * @param metadataCacheEnabled Whether to cache the batches' metadata in memory. 
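Since the primary constructor now only needs a Hadoop Configuration, metadata logs can be opened outside of a SparkSession context, which is what OperatorStateMetadataLog and the state metadata reader above rely on. A minimal sketch, with a hypothetical checkpoint location:

    import org.apache.hadoop.conf.Configuration
    import org.apache.spark.sql.execution.streaming.OperatorStateMetadataLog

    val hadoopConf = new Configuration()
    // HDFSMetadataLog subclasses can now be built from a plain Hadoop conf; the
    // SparkSession auxiliary constructor keeps the existing call sites working.
    val metadataLog = new OperatorStateMetadataLog(
      hadoopConf, "/tmp/checkpoint/state/0/_metadata/metadata/v2")   // hypothetical location
    metadataLog.getLatest().foreach { case (batchId, metadata) =>
      println(s"latest operator metadata was written for batch $batchId (v${metadata.version})")
    }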
*/ -class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: String) +class HDFSMetadataLog[T <: AnyRef : ClassTag]( + hadoopConf: Configuration, + path: String, + val metadataCacheEnabled: Boolean = false) extends MetadataLog[T] with Logging { + def this(sparkSession: SparkSession, path: String) = { + this( + sparkSession.sessionState.newHadoopConf(), + path, + metadataCacheEnabled = sparkSession.sessionState.conf.getConf( + SQLConf.STREAMING_METADATA_CACHE_ENABLED) + ) + } + private implicit val formats: Formats = Serialization.formats(NoTypeHints) /** Needed to serialize type T into JSON when using Jackson */ @@ -64,15 +80,12 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: val metadataPath = new Path(path) protected val fileManager = - CheckpointFileManager.create(metadataPath, sparkSession.sessionState.newHadoopConf()) + CheckpointFileManager.create(metadataPath, hadoopConf) if (!fileManager.exists(metadataPath)) { fileManager.mkdirs(metadataPath) } - protected val metadataCacheEnabled: Boolean - = sparkSession.sessionState.conf.getConf(SQLConf.STREAMING_METADATA_CACHE_ENABLED) - /** * Cache the latest two batches. [[StreamExecution]] usually just accesses the latest two batches * when committing offsets, this cache will save some file system operations. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 7e7e9d76081e1..3fe5aeae5f637 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadat import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1 -import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV1, OperatorStateMetadataWriter} +import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV1, OperatorStateMetadataV2, OperatorStateMetadataWriter} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -85,6 +85,10 @@ class IncrementalExecution( .map(SQLConf.SHUFFLE_PARTITIONS.valueConverter) .getOrElse(sparkSession.sessionState.conf.numShufflePartitions) + /** + * This value dictates which schema format version the state schema should be written in + * for all operators other than TransformWithState. 
+ */ private val STATE_SCHEMA_DEFAULT_VERSION: Int = 2 /** @@ -206,11 +210,21 @@ class IncrementalExecution( // write out the state schema paths to the metadata file statefulOp match { case stateStoreWriter: StateStoreWriter => - val metadata = stateStoreWriter.operatorStateMetadata() - // TODO: Populate metadata with stateSchemaPaths if metadata version is v2 - val metadataWriter = new OperatorStateMetadataWriter(new Path( - checkpointLocation, stateStoreWriter.getStateInfo.operatorId.toString), hadoopConf) - metadataWriter.write(metadata) + val metadata = stateStoreWriter.operatorStateMetadata(stateSchemaPaths) + stateStoreWriter match { + case tws: TransformWithStateExec => + val metadataPath = OperatorStateMetadataV2.metadataFilePath(new Path( + checkpointLocation, tws.getStateInfo.operatorId.toString)) + val operatorStateMetadataLog = new OperatorStateMetadataLog(sparkSession, + metadataPath.toString) + operatorStateMetadataLog.add(currentBatchId, metadata) + case _ => + val metadataWriter = new OperatorStateMetadataWriter(new Path( + checkpointLocation, + stateStoreWriter.getStateInfo.operatorId.toString), + hadoopConf) + metadataWriter.write(metadata) + } case _ => } statefulOp @@ -452,11 +466,11 @@ class IncrementalExecution( new Path(checkpointLocation).getParent.toString, new SerializableConfiguration(hadoopConf)) val opMetadataList = reader.allOperatorStateMetadata - ret = opMetadataList.map { operatorMetadata => - val metadataInfoV1 = operatorMetadata - .asInstanceOf[OperatorStateMetadataV1] - .operatorInfo - metadataInfoV1.operatorId -> metadataInfoV1.operatorName + ret = opMetadataList.map { + case (OperatorStateMetadataV1(operatorInfo, _), _) => + operatorInfo.operatorId -> operatorInfo.operatorName + case (OperatorStateMetadataV2(operatorInfo, _, _), _) => + operatorInfo.operatorId -> operatorInfo.operatorName }.toMap } catch { case e: Exception => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala new file mode 100644 index 0000000000000..80e334be1101f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming + +import java.io.{BufferedReader, InputStream, InputStreamReader, OutputStream} +import java.nio.charset.StandardCharsets +import java.nio.charset.StandardCharsets._ + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FSDataOutputStream + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataV1, OperatorStateMetadataV2} +import org.apache.spark.sql.internal.SQLConf + + +class OperatorStateMetadataLog( + hadoopConf: Configuration, + path: String, + metadataCacheEnabled: Boolean = false) + extends HDFSMetadataLog[OperatorStateMetadata](hadoopConf, path, metadataCacheEnabled) { + + def this(sparkSession: SparkSession, path: String) = { + this( + sparkSession.sessionState.newHadoopConf(), + path, + metadataCacheEnabled = sparkSession.sessionState.conf.getConf( + SQLConf.STREAMING_METADATA_CACHE_ENABLED) + ) + } + + override protected def serialize(metadata: OperatorStateMetadata, out: OutputStream): Unit = { + val fsDataOutputStream = out.asInstanceOf[FSDataOutputStream] + fsDataOutputStream.write(s"v${metadata.version}\n".getBytes(StandardCharsets.UTF_8)) + metadata.version match { + case 1 => + OperatorStateMetadataV1.serialize(fsDataOutputStream, metadata) + case 2 => + OperatorStateMetadataV2.serialize(fsDataOutputStream, metadata) + } + } + + override protected def deserialize(in: InputStream): OperatorStateMetadata = { + // called inside a try-finally where the underlying stream is closed in the caller + // create buffered reader from input stream + val bufferedReader = new BufferedReader(new InputStreamReader(in, UTF_8)) + // read first line for version number, in the format "v{version}" + val version = bufferedReader.readLine() + version match { + case "v1" => OperatorStateMetadataV1.deserialize(bufferedReader) + case "v2" => OperatorStateMetadataV2.deserialize(bufferedReader) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala index ed881b49ec1e9..a802ef63576f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala @@ -26,6 +26,11 @@ import org.apache.spark.sql.execution.streaming.state.StateStoreErrors import org.apache.spark.sql.types.{BinaryType, LongType, StructType} object TransformWithStateKeyValueRowSchema { + /** + * The following are the key/value row schema used in StateStore layer. + * Key/value rows will be serialized into Binary format in `StateTypesEncoder`. + * The "real" key/value row schema will be written into state schema metadata. + */ val KEY_ROW_SCHEMA: StructType = new StructType().add("key", BinaryType) val COMPOSITE_KEY_ROW_SCHEMA: StructType = new StructType() .add("key", BinaryType) @@ -35,6 +40,37 @@ object TransformWithStateKeyValueRowSchema { val VALUE_ROW_SCHEMA_WITH_TTL: StructType = new StructType() .add("value", BinaryType) .add("ttlExpirationMs", LongType) + + /** Helper functions for passing the key/value schema to write to state schema metadata. */ + + /** + * Return key schema with key column name. 
+ */ + def getKeySchema(schema: StructType): StructType = { + new StructType().add("key", schema) + } + + /** + * Return value schema with additional TTL column if TTL is enabled. + */ + def getValueSchemaWithTTL(schema: StructType, hasTTL: Boolean): StructType = { + val valSchema = if (hasTTL) { + new StructType(schema.fields).add("ttlExpirationMs", LongType) + } else schema + new StructType() + .add("value", valSchema) + } + + /** + * Given grouping key and user key schema, return the schema of the composite key. + */ + def getCompositeKeySchema( + groupingKeySchema: StructType, + userKeySchema: StructType): StructType = { + new StructType() + .add("key", new StructType(groupingKeySchema.fields)) + .add("userKey", new StructType(userKeySchema.fields)) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 893163a58a1b9..65b435b5c692c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -87,7 +87,7 @@ class StatefulProcessorHandleImpl( isStreaming: Boolean = true, batchTimestampMs: Option[Long] = None, metrics: Map[String, SQLMetric] = Map.empty) - extends StatefulProcessorHandleImplBase(timeMode) with Logging { + extends StatefulProcessorHandleImplBase(timeMode, keyEncoder) with Logging { import StatefulProcessorHandleState._ /** @@ -297,8 +297,8 @@ class StatefulProcessorHandleImpl( * actually done. We need this class because we can only collect the schemas after * the StatefulProcessor is initialized. */ -class DriverStatefulProcessorHandleImpl(timeMode: TimeMode) - extends StatefulProcessorHandleImplBase(timeMode) { +class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: ExpressionEncoder[Any]) + extends StatefulProcessorHandleImplBase(timeMode, keyExprEnc) { private[sql] val columnFamilySchemaUtils = ColumnFamilySchemaUtilsV1 @@ -309,120 +309,55 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode) def getColumnFamilySchemas: Map[String, ColumnFamilySchema] = columnFamilySchemas.toMap - /** - * Function to add the ValueState schema to the list of column family schemas. - * The user must ensure to call this function only within the `init()` method of the - * StatefulProcessor. - * - * @param stateName - name of the state variable - * @param valEncoder - SQL encoder for state variable - * @tparam T - type of state variable - * @return - instance of ValueState of type T that can be used to store state persistently - */ override def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] = { verifyStateVarOperations("get_value_state", PRE_INIT) val colFamilySchema = columnFamilySchemaUtils. - getValueStateSchema(stateName, false) + getValueStateSchema(stateName, keyExprEnc, valEncoder, false) columnFamilySchemas.put(stateName, colFamilySchema) null.asInstanceOf[ValueState[T]] } - /** - * Function to add the ValueStateWithTTL schema to the list of column family schemas. - * The user must ensure to call this function only within the `init()` method of the - * StatefulProcessor. - * - * @param stateName - name of the state variable - * @param valEncoder - SQL encoder for state variable - * @param ttlConfig - the ttl configuration (time to live duration etc.) 
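To make the shapes produced by the schema helpers above concrete, a short sketch; the grouping-key and value schemas below are illustrative stand-ins for keyEncoder.schema and valEncoder.schema.

    import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema._
    import org.apache.spark.sql.types._

    val groupingKeySchema = new StructType().add("value", StringType)    // e.g. a String grouping key
    val valueSchema = new StructType().add("value", LongType, false)     // e.g. a Long count

    // struct<key: struct<value: string>>
    val keySchema = getKeySchema(groupingKeySchema)

    // TTL enabled:  struct<value: struct<value: bigint, ttlExpirationMs: bigint>>
    // TTL disabled: struct<value: struct<value: bigint>>
    val valueWithTtl = getValueSchemaWithTTL(valueSchema, hasTTL = true)

    // struct<key: struct<value: string>, userKey: struct<id: int>>
    val compositeKeySchema = getCompositeKeySchema(
      groupingKeySchema, new StructType().add("id", IntegerType))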
- * @tparam T - type of state variable - * @return - instance of ValueState of type T that can be used to store state persistently - */ override def getValueState[T]( stateName: String, valEncoder: Encoder[T], ttlConfig: TTLConfig): ValueState[T] = { verifyStateVarOperations("get_value_state", PRE_INIT) val colFamilySchema = columnFamilySchemaUtils. - getValueStateSchema(stateName, true) + getValueStateSchema(stateName, keyExprEnc, valEncoder, true) columnFamilySchemas.put(stateName, colFamilySchema) null.asInstanceOf[ValueState[T]] } - /** - * Function to add the ListState schema to the list of column family schemas. - * The user must ensure to call this function only within the `init()` method of the - * StatefulProcessor. - * - * @param stateName - name of the state variable - * @param valEncoder - SQL encoder for state variable - * @tparam T - type of state variable - * @return - instance of ListState of type T that can be used to store state persistently - */ override def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] = { verifyStateVarOperations("get_list_state", PRE_INIT) val colFamilySchema = columnFamilySchemaUtils. - getListStateSchema(stateName, false) + getListStateSchema(stateName, keyExprEnc, valEncoder, false) columnFamilySchemas.put(stateName, colFamilySchema) null.asInstanceOf[ListState[T]] } - /** - * Function to add the ListStateWithTTL schema to the list of column family schemas. - * The user must ensure to call this function only within the `init()` method of the - * StatefulProcessor. - * - * @param stateName - name of the state variable - * @param valEncoder - SQL encoder for state variable - * @param ttlConfig - the ttl configuration (time to live duration etc.) - * @tparam T - type of state variable - * @return - instance of ListState of type T that can be used to store state persistently - */ override def getListState[T]( stateName: String, valEncoder: Encoder[T], ttlConfig: TTLConfig): ListState[T] = { verifyStateVarOperations("get_list_state", PRE_INIT) val colFamilySchema = columnFamilySchemaUtils. - getListStateSchema(stateName, true) + getListStateSchema(stateName, keyExprEnc, valEncoder, true) columnFamilySchemas.put(stateName, colFamilySchema) null.asInstanceOf[ListState[T]] } - /** - * Function to add the MapState schema to the list of column family schemas. - * The user must ensure to call this function only within the `init()` method of the - * StatefulProcessor. - * @param stateName - name of the state variable - * @param userKeyEnc - spark sql encoder for the map key - * @param valEncoder - spark sql encoder for the map value - * @tparam K - type of key for map state variable - * @tparam V - type of value for map state variable - * @return - instance of MapState of type [K,V] that can be used to store state persistently - */ override def getMapState[K, V]( stateName: String, userKeyEnc: Encoder[K], valEncoder: Encoder[V]): MapState[K, V] = { verifyStateVarOperations("get_map_state", PRE_INIT) val colFamilySchema = columnFamilySchemaUtils. - getMapStateSchema(stateName, userKeyEnc, false) + getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, false) columnFamilySchemas.put(stateName, colFamilySchema) null.asInstanceOf[MapState[K, V]] } - /** - * Function to add the MapStateWithTTL schema to the list of column family schemas. - * The user must ensure to call this function only within the `init()` method of the - * StatefulProcessor. 
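These PRE_INIT overrides are only driven from the driver; the flow, mirroring getDriverProcessorHandle in TransformWithStateExec further below, is roughly the sketch that follows. The helper name is hypothetical, and it assumes the code sits in the org.apache.spark.sql.execution.streaming package so that the driver handle and the package-private setHandle are visible.

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    import org.apache.spark.sql.execution.streaming.state.ColumnFamilySchema
    import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor, TimeMode}

    // Hypothetical helper: run only the processor's init() against the driver handle so
    // that the column family schemas get recorded. No state store is touched; the state
    // handles returned while in PRE_INIT are null placeholders.
    def collectColumnFamilySchemas[K, I, O](
        processor: StatefulProcessor[K, I, O],
        keyEncoder: ExpressionEncoder[Any],
        timeMode: TimeMode,
        outputMode: OutputMode): Map[String, ColumnFamilySchema] = {
      val handle = new DriverStatefulProcessorHandleImpl(timeMode, keyEncoder)
      handle.setHandleState(StatefulProcessorHandleState.PRE_INIT)
      processor.setHandle(handle)
      processor.init(outputMode, timeMode)
      val schemas = handle.getColumnFamilySchemas
      processor.close()
      processor.setHandle(null)
      schemas
    }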
- * @param stateName - name of the state variable - * @param userKeyEnc - spark sql encoder for the map key - * @param valEncoder - SQL encoder for state variable - * @param ttlConfig - the ttl configuration (time to live duration etc.) - * @tparam K - type of key for map state variable - * @tparam V - type of value for map state variable - * @return - instance of MapState of type [K,V] that can be used to store state persistently - */ override def getMapState[K, V]( stateName: String, userKeyEnc: Encoder[K], @@ -430,7 +365,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode) ttlConfig: TTLConfig): MapState[K, V] = { verifyStateVarOperations("get_map_state", PRE_INIT) val colFamilySchema = columnFamilySchemaUtils. - getMapStateSchema(stateName, userKeyEnc, true) + getMapStateSchema(stateName, keyExprEnc, valEncoder, userKeyEnc, true) columnFamilySchemas.put(stateName, colFamilySchema) null.asInstanceOf[MapState[K, V]] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImplBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImplBase.scala index 3b952967e35d9..64d87073ccf9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImplBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImplBase.scala @@ -16,13 +16,14 @@ */ package org.apache.spark.sql.execution.streaming +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.plans.logical.NoTime import org.apache.spark.sql.execution.streaming.StatefulProcessorHandleState.{INITIALIZED, PRE_INIT, StatefulProcessorHandleState, TIMER_PROCESSED} import org.apache.spark.sql.execution.streaming.state.StateStoreErrors import org.apache.spark.sql.streaming.{StatefulProcessorHandle, TimeMode} -abstract class StatefulProcessorHandleImplBase(timeMode: TimeMode) - extends StatefulProcessorHandle { +abstract class StatefulProcessorHandleImplBase( + timeMode: TimeMode, keyExprEnc: ExpressionEncoder[Any]) extends StatefulProcessorHandle { protected var currState: StatefulProcessorHandleState = PRE_INIT diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index 3d7e1900eebb8..b3b8f66cd547e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -227,10 +227,12 @@ case class StreamingSymmetricHashJoinExec( private val stateStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide) - override def operatorStateMetadata(): OperatorStateMetadata = { + override def operatorStateMetadata( + stateSchemaPaths: Array[String] = Array.empty): OperatorStateMetadata = { val info = getStateInfo val operatorInfo = OperatorInfoV1(info.operatorId, shortName) - val stateStoreInfo = stateStoreNames.map(StateStoreMetadataV1(_, 0, info.numPartitions)).toArray + val stateStoreInfo = + stateStoreNames.map(StateStoreMetadataV1(_, 0, info.numPartitions)).toArray OperatorStateMetadataV1(operatorInfo, stateStoreInfo) } @@ -249,8 +251,7 @@ case class StreamingSymmetricHashJoinExec( override def validateAndMaybeEvolveStateSchema( hadoopConf: Configuration, batchId: Long, - 
stateSchemaVersion: Int - ): Array[String] = { + stateSchemaVersion: Int): Array[String] = { var result: Map[String, (StructType, StructType)] = Map.empty // get state schema for state stores on left side of the join result ++= SymmetricHashJoinStateManager.getSchemaForStateStores(LeftSide, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index acf46df9cc1fa..42dea5d7c2452 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -21,6 +21,10 @@ import java.util.concurrent.TimeUnit.NANOSECONDS import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.json4s.JsonAST.JValue +import org.json4s.JsonDSL._ +import org.json4s.JString +import org.json4s.jackson.JsonMethods.{compact, render} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD @@ -93,13 +97,15 @@ case class TransformWithStateExec( } } + override def operatorStateMetadataVersion: Int = 2 + /** * We initialize this processor handle in the driver to run the init function * and fetch the schemas of the state variables initialized in this processor. * @return a new instance of the driver processor handle */ - private def getDriverProcessorHandle: DriverStatefulProcessorHandleImpl = { - val driverProcessorHandle = new DriverStatefulProcessorHandleImpl(timeMode) + private def getDriverProcessorHandle(): DriverStatefulProcessorHandleImpl = { + val driverProcessorHandle = new DriverStatefulProcessorHandleImpl(timeMode, keyEncoder) driverProcessorHandle.setHandleState(StatefulProcessorHandleState.PRE_INIT) statefulProcessor.setHandle(driverProcessorHandle) statefulProcessor.init(outputMode, timeMode) @@ -111,12 +117,16 @@ case class TransformWithStateExec( * after init is called. */ private def getColFamilySchemas(): Map[String, ColumnFamilySchema] = { - val driverProcessorHandle = getDriverProcessorHandle - val columnFamilySchemas = driverProcessorHandle.getColumnFamilySchemas + val columnFamilySchemas = getDriverProcessorHandle().getColumnFamilySchemas closeProcessorHandle() columnFamilySchemas } + /** + * This method is used for the driver-side stateful processor after we + * have collected all the necessary schemas. + * This instance of the stateful processor won't be used again. 
+ */ private def closeProcessorHandle(): Unit = { statefulProcessor.close() statefulProcessor.setHandle(null) @@ -370,20 +380,37 @@ case class TransformWithStateExec( ) } + private def fetchOperatorStateMetadataLog( + hadoopConf: Configuration, + checkpointDir: String, + operatorId: Long): OperatorStateMetadataLog = { + val checkpointPath = new Path(checkpointDir, operatorId.toString) + val operatorStateMetadataPath = OperatorStateMetadataV2.metadataFilePath(checkpointPath) + new OperatorStateMetadataLog(hadoopConf, operatorStateMetadataPath.toString) + } + override def validateAndMaybeEvolveStateSchema( hadoopConf: Configuration, batchId: Long, - stateSchemaVersion: Int): - Array[String] = { + stateSchemaVersion: Int): Array[String] = { assert(stateSchemaVersion >= 3) - val newColumnFamilySchemas = getColFamilySchemas() + val newSchemas = getColFamilySchemas() val schemaFile = new StateSchemaV3File( - hadoopConf, stateSchemaFilePath(StateStoreId.DEFAULT_STORE_NAME).toString) + hadoopConf, stateSchemaDirPath(StateStoreId.DEFAULT_STORE_NAME).toString) // TODO: Read the schema path from the OperatorStateMetadata file // and validate it with the new schema + val operatorStateMetadataLog = fetchOperatorStateMetadataLog( + hadoopConf, getStateInfo.checkpointLocation, getStateInfo.operatorId) + val mostRecentLog = operatorStateMetadataLog.getLatest() + val oldSchemas = mostRecentLog.map(_._2.asInstanceOf[OperatorStateMetadataV2]) + .map(_.stateStoreInfo.map(_.stateSchemaFilePath)).getOrElse(Array.empty) + .flatMap { schemaPath => + schemaFile.getWithPath(new Path(schemaPath)) + }.toList + validateSchemas(oldSchemas, newSchemas) // Write the new schema to the schema file - val schemaPath = schemaFile.addWithUUID(batchId, newColumnFamilySchemas.values.toList) + val schemaPath = schemaFile.addWithUUID(batchId, newSchemas.values.toList) Array(schemaPath.toString) } @@ -402,7 +429,7 @@ case class TransformWithStateExec( } } - private def stateSchemaFilePath(storeName: String): Path = { + private def stateSchemaDirPath(storeName: String): Path = { assert(storeName == StateStoreId.DEFAULT_STORE_NAME) def stateInfo = getStateInfo val stateCheckpointPath = @@ -413,6 +440,24 @@ case class TransformWithStateExec( new Path(new Path(storeNamePath, "_metadata"), "schema") } + /** Metadata of this stateful operator and its states stores. 
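Putting the pieces above together, the read path for an upgraded operator looks roughly like the following sketch, which mirrors fetchOperatorStateMetadataLog above and the test helpers later in this PR; the checkpoint location is hypothetical.

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.Path
    import org.apache.spark.sql.execution.streaming.OperatorStateMetadataLog
    import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV2, StateSchemaV3File}

    val hadoopConf = new Configuration()
    val stateCheckpointPath = new Path("/tmp/checkpoint/state/0")            // hypothetical
    val metadataLog = new OperatorStateMetadataLog(
      hadoopConf, OperatorStateMetadataV2.metadataFilePath(stateCheckpointPath).toString)
    // The directory given to StateSchemaV3File only matters for new writes; the reads
    // below go through the absolute paths recorded in the operator metadata.
    val schemaFile = new StateSchemaV3File(
      hadoopConf, new Path(new Path(stateCheckpointPath, "_metadata"), "schema").toString)
    // The latest batch's metadata points at the schema file that was current for that batch.
    val latestSchemas = metadataLog.getLatest()
      .map(_._2.asInstanceOf[OperatorStateMetadataV2]).toSeq
      .flatMap(_.stateStoreInfo.map(_.stateSchemaFilePath))
      .flatMap(schemaPath => schemaFile.getWithPath(new Path(schemaPath)))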
*/ + override def operatorStateMetadata( + stateSchemaPaths: Array[String] = Array.empty): OperatorStateMetadata = { + val info = getStateInfo + val operatorInfo = OperatorInfoV1(info.operatorId, shortName) + // stateSchemaFilePath should be populated at this point + val stateStoreInfo = + Array(StateStoreMetadataV2( + StateStoreId.DEFAULT_STORE_NAME, 0, info.numPartitions, stateSchemaPaths.head)) + + val operatorPropertiesJson: JValue = + ("timeMode" -> JString(timeMode.toString)) ~ + ("outputMode" -> JString(outputMode.toString)) + + val json = compact(render(operatorPropertiesJson)) + OperatorStateMetadataV2(operatorInfo, stateStoreInfo, json) + } + override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala index 8ce883038401d..a319f93fe82b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, Metadata /** * Metadata for a state store instance. */ -trait StateStoreMetadata { +trait StateStoreMetadata extends Serializable { def storeName: String def numColsPrefixKey: Int def numPartitions: Int @@ -42,18 +42,35 @@ trait StateStoreMetadata { case class StateStoreMetadataV1(storeName: String, numColsPrefixKey: Int, numPartitions: Int) extends StateStoreMetadata +case class StateStoreMetadataV2( + storeName: String, + numColsPrefixKey: Int, + numPartitions: Int, + stateSchemaFilePath: String) + extends StateStoreMetadata with Serializable + +object StateStoreMetadataV2 { + private implicit val formats: Formats = Serialization.formats(NoTypeHints) + + @scala.annotation.nowarn + private implicit val manifest = Manifest + .classType[StateStoreMetadataV2](implicitly[ClassTag[StateStoreMetadataV2]].runtimeClass) +} + /** * Information about a stateful operator. 
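The operatorPropertiesJson built in operatorStateMetadata above is a plain JSON object; for TimeMode.None() with Update output it renders to the exact string asserted in the suite below.

    import org.json4s.JString
    import org.json4s.JsonDSL._
    import org.json4s.jackson.JsonMethods.{compact, render}

    // Same construction as operatorStateMetadata() above, with concrete values:
    val operatorProperties =
      ("timeMode" -> JString("NoTime")) ~ ("outputMode" -> JString("Update"))
    val json = compact(render(operatorProperties))
    assert(json == """{"timeMode":"NoTime","outputMode":"Update"}""")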
*/ -trait OperatorInfo { +trait OperatorInfo extends Serializable { def operatorId: Long def operatorName: String } case class OperatorInfoV1(operatorId: Long, operatorName: String) extends OperatorInfo -trait OperatorStateMetadata { +trait OperatorStateMetadata extends Serializable { def version: Int + + def operatorInfo: OperatorInfo } case class OperatorStateMetadataV1( @@ -62,6 +79,13 @@ case class OperatorStateMetadataV1( override def version: Int = 1 } +case class OperatorStateMetadataV2( + operatorInfo: OperatorInfoV1, + stateStoreInfo: Array[StateStoreMetadataV2], + operatorPropertiesJson: String) extends OperatorStateMetadata { + override def version: Int = 2 +} + object OperatorStateMetadataV1 { private implicit val formats: Formats = Serialization.formats(NoTypeHints) @@ -84,6 +108,27 @@ object OperatorStateMetadataV1 { } } +object OperatorStateMetadataV2 { + private implicit val formats: Formats = Serialization.formats(NoTypeHints) + + @scala.annotation.nowarn + private implicit val manifest = Manifest + .classType[OperatorStateMetadataV2](implicitly[ClassTag[OperatorStateMetadataV2]].runtimeClass) + + def metadataFilePath(stateCheckpointPath: Path): Path = + new Path(new Path(new Path(stateCheckpointPath, "_metadata"), "metadata"), "v2") + + def deserialize(in: BufferedReader): OperatorStateMetadata = { + Serialization.read[OperatorStateMetadataV2](in) + } + + def serialize( + out: FSDataOutputStream, + operatorStateMetadata: OperatorStateMetadata): Unit = { + Serialization.write(operatorStateMetadata.asInstanceOf[OperatorStateMetadataV2], out) + } +} + /** * Write OperatorStateMetadata into the state checkpoint directory. */ @@ -114,7 +159,9 @@ class OperatorStateMetadataWriter(stateCheckpointPath: Path, hadoopConf: Configu } /** - * Read OperatorStateMetadata from the state checkpoint directory. + * Read OperatorStateMetadata from the state checkpoint directory. This class will only be + * used to read OperatorStateMetadataV1. + * OperatorStateMetadataV2 will be read by the OperatorStateMetadataLog. 
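A minimal write/read round trip for the V2 metadata through OperatorStateMetadataLog, mirroring the serialization test added in TransformWithStateSuite; all values and paths are illustrative.

    import org.apache.hadoop.conf.Configuration
    import org.apache.spark.sql.execution.streaming.OperatorStateMetadataLog
    import org.apache.spark.sql.execution.streaming.state.{OperatorInfoV1, OperatorStateMetadataV2, StateStoreMetadataV2}

    val metadata = OperatorStateMetadataV2(
      OperatorInfoV1(0, "transformWithStateExec"),
      Array(StateStoreMetadataV2("default", 0, 5, "/tmp/schema/file")),   // illustrative values
      """{"timeMode":"NoTime","outputMode":"Update"}""")

    val metadataLog = new OperatorStateMetadataLog(
      new Configuration(), "/tmp/operator-metadata-log")                  // hypothetical directory
    metadataLog.add(0, metadata)   // serialize() writes a "v2" header line, then the JSON payload
    val roundTripped = metadataLog.get(0).map(_.asInstanceOf[OperatorStateMetadataV2])
    assert(roundTripped.exists(_.operatorPropertiesJson == metadata.operatorPropertiesJson))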
*/ class OperatorStateMetadataReader(stateCheckpointPath: Path, hadoopConf: Configuration) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala index 40c2b403b04e5..0a8021ab3de2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala @@ -73,7 +73,7 @@ object ColumnFamilySchemaV1 { s"Expected Map but got ${colFamilyMap.getClass}") val keySchema = StructType.fromString(colFamilyMap("keySchema").asInstanceOf[String]) val valueSchema = StructType.fromString(colFamilyMap("valueSchema").asInstanceOf[String]) - new ColumnFamilySchemaV1( + ColumnFamilySchemaV1( colFamilyMap("columnFamilyName").asInstanceOf[String], keySchema, valueSchema, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala index f7c58a4d4b752..8aabc0846fe61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala @@ -113,7 +113,6 @@ class StateSchemaCompatibilityChecker( object StateSchemaCompatibilityChecker extends Logging { val VERSION = 2 - /** * Function to check if new state store schema is compatible with the existing schema. * @param oldSchema - old state schema diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaV3File.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaV3File.scala index 5cee6eb807c4f..a73bf6ff5a320 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaV3File.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaV3File.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.execution.streaming.MetadataVersionUtil.validateVers * The StateSchemaV3File is used to write the schema of multiple column families. * Right now, this is primarily used for the TransformWithState operator, which supports * multiple column families to keep the data for multiple state variables. + * We only expect ColumnFamilySchemaV1 to be written and read from this file. * @param hadoopConf Hadoop configuration that is used to read / write metadata files. * @param path Path to the directory that will be used for writing metadata. 
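For reference, a sketch of how TransformWithStateExec uses this file during planning; the schema below is the wrapped countState schema used in the suite, and the directory is hypothetical.

    import org.apache.hadoop.conf.Configuration
    import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.KEY_ROW_SCHEMA
    import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, StateSchemaV3File}
    import org.apache.spark.sql.types._

    val schemaFile = new StateSchemaV3File(new Configuration(), "/tmp/state/0/_metadata/schema")
    val countStateSchema = ColumnFamilySchemaV1(
      "countState",
      new StructType().add("key", new StructType().add("value", StringType)),
      new StructType().add("value", new StructType().add("value", LongType, false)),
      NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA),
      None)
    // Writes a "v3" header plus one JSON line per column family into a new file and returns
    // that file's path, which ends up in StateStoreMetadataV2.stateSchemaFilePath.
    val schemaFilePath = schemaFile.addWithUUID(0, List(countStateSchema))
    val readBack = schemaFile.getWithPath(schemaFilePath)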
*/ @@ -40,8 +41,6 @@ class StateSchemaV3File( hadoopConf: Configuration, path: String) { - val VERSION = 3 - val metadataPath = new Path(path) protected val fileManager: CheckpointFileManager = @@ -51,7 +50,7 @@ class StateSchemaV3File( fileManager.mkdirs(metadataPath) } - def deserialize(in: InputStream): List[ColumnFamilySchema] = { + private[sql] def deserialize(in: InputStream): List[ColumnFamilySchema] = { val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines() if (!lines.hasNext) { @@ -59,13 +58,13 @@ class StateSchemaV3File( } val version = lines.next().trim - validateVersion(version, VERSION) + validateVersion(version, StateSchemaV3File.VERSION) lines.map(ColumnFamilySchemaV1.fromJson).toList } - def serialize(schemas: List[ColumnFamilySchema], out: OutputStream): Unit = { - out.write(s"v${VERSION}".getBytes(UTF_8)) + private[sql] def serialize(schemas: List[ColumnFamilySchema], out: OutputStream): Unit = { + out.write(s"v${StateSchemaV3File.VERSION}".getBytes(UTF_8)) out.write('\n') out.write(schemas.map(_.json).mkString("\n").getBytes(UTF_8)) } @@ -85,7 +84,6 @@ class StateSchemaV3File( protected def write( batchMetadataFile: Path, fn: OutputStream => Unit): Unit = { - // Only write metadata when the batch has not yet been written val output = fileManager.createAtomic(batchMetadataFile, overwriteIfPossible = false) try { fn(output) @@ -101,3 +99,7 @@ class StateSchemaV3File( new Path(metadataPath, batchId.toString) } } + +object StateSchemaV3File { + val VERSION = 3 +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 41167a6c917d7..387d0aeaab24c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.util.UnsafeRowUtils import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo +import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, KEY_ROW_SCHEMA} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{ThreadUtils, Utils} @@ -292,14 +293,14 @@ object KeyStateEncoderSpec { // match on type m("keyStateEncoderType").asInstanceOf[String] match { case "NoPrefixKeyStateEncoderSpec" => - NoPrefixKeyStateEncoderSpec(keySchema) + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA) case "RangeKeyScanStateEncoderSpec" => val orderingOrdinals = m("orderingOrdinals"). 
- asInstanceOf[List[_]].map(_.asInstanceOf[Int]) + asInstanceOf[List[_]].map(_.asInstanceOf[BigInt].toInt) RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) case "PrefixKeyScanStateEncoderSpec" => - val numColsPrefixKey = m("numColsPrefixKey").asInstanceOf[Int] - PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) + val numColsPrefixKey = m("numColsPrefixKey").asInstanceOf[BigInt] + PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, numColsPrefixKey.toInt) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 94d976b568a5e..492ba37865d42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -24,6 +24,7 @@ import scala.collection.mutable import scala.jdk.CollectionConverters._ import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD @@ -73,6 +74,12 @@ trait StatefulOperator extends SparkPlan { } } + def metadataFilePath(): Path = { + val stateCheckpointPath = + new Path(getStateInfo.checkpointLocation, getStateInfo.operatorId.toString) + new Path(new Path(stateCheckpointPath, "_metadata"), "metadata") + } + // Function used to record state schema for the first time and validate it against proposed // schema changes in the future. Runs as part of a planning rule on the driver. // Returns the schema file path for operators that write this to the metadata file, @@ -141,6 +148,8 @@ trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: Sp */ def produceOutputWatermark(inputWatermarkMs: Long): Option[Long] = Some(inputWatermarkMs) + def operatorStateMetadataVersion: Int = 1 + override lazy val metrics = statefulOperatorCustomMetrics ++ Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "numRowsDroppedByWatermark" -> SQLMetrics.createMetric(sparkContext, @@ -156,6 +165,20 @@ trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: Sp "number of state store instances") ) ++ stateStoreCustomMetrics ++ pythonMetrics + def stateSchemaFilePath(storeName: Option[String] = None): Path = { + def stateInfo = getStateInfo + val stateCheckpointPath = + new Path(getStateInfo.checkpointLocation, + s"${stateInfo.operatorId.toString}") + storeName match { + case Some(storeName) => + val storeNamePath = new Path(stateCheckpointPath, storeName) + new Path(new Path(storeNamePath, "_metadata"), "schema") + case None => + new Path(new Path(stateCheckpointPath, "_metadata"), "schema") + } + } + /** * Get the progress made by this stateful operator after execution. This should be called in * the driver after this SparkPlan has been executed and metrics have been updated. @@ -189,7 +212,8 @@ trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: Sp protected def timeTakenMs(body: => Unit): Long = Utils.timeTakenMs(body)._2 /** Metadata of this stateful operator and its states stores. 
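On the cast change in KeyStateEncoderSpec.fromJson above: json4s parses JSON integers into JInt, which is backed by scala.math.BigInt, so extracting the parsed document into a Map[String, Any] yields BigInt values and a direct asInstanceOf[Int] fails at runtime. A minimal illustration using the json4s-jackson dependency already present in this module:

    import org.json4s.DefaultFormats
    import org.json4s.jackson.JsonMethods

    implicit val formats: DefaultFormats.type = DefaultFormats
    val parsed = JsonMethods.parse("""{"numColsPrefixKey": 1}""").extract[Map[String, Any]]
    // parsed("numColsPrefixKey") is a scala.math.BigInt here, not an Int,
    // hence the asInstanceOf[BigInt].toInt pattern used in fromJson.
    val numColsPrefixKey = parsed("numColsPrefixKey").asInstanceOf[BigInt].toInt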
*/ - def operatorStateMetadata(): OperatorStateMetadata = { + def operatorStateMetadata( + stateSchemaPaths: Array[String] = Array.empty): OperatorStateMetadata = { val info = getStateInfo val operatorInfo = OperatorInfoV1(info.operatorId, shortName) val stateStoreInfo = @@ -910,7 +934,8 @@ case class SessionWindowStateStoreSaveExec( keyWithoutSessionExpressions, getStateInfo, conf) :: Nil } - override def operatorStateMetadata(): OperatorStateMetadata = { + override def operatorStateMetadata( + stateSchemaPaths: Array[String] = Array.empty): OperatorStateMetadata = { val info = getStateInfo val operatorInfo = OperatorInfoV1(info.operatorId, shortName) val stateStoreInfo = Array(StateStoreMetadataV1( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 98b2030f1bac4..2c4111ec026ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -30,6 +30,8 @@ import scala.util.Random import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ +import org.json4s.DefaultFormats +import org.json4s.jackson.JsonMethods import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ @@ -1627,6 +1629,30 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] keyRow, keySchema, valueRow, keySchema, storeConf) } + test("test serialization and deserialization of NoPrefixKeyStateEncoderSpec") { + implicit val formats: DefaultFormats.type = DefaultFormats + val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema) + val jsonMap = JsonMethods.parse(encoderSpec.json).extract[Map[String, Any]] + val deserializedEncoderSpec = KeyStateEncoderSpec.fromJson(keySchema, jsonMap) + assert(encoderSpec == deserializedEncoderSpec) + } + + test("test serialization and deserialization of PrefixKeyScanStateEncoderSpec") { + implicit val formats: DefaultFormats.type = DefaultFormats + val encoderSpec = PrefixKeyScanStateEncoderSpec(keySchema, 1) + val jsonMap = JsonMethods.parse(encoderSpec.json).extract[Map[String, Any]] + val deserializedEncoderSpec = KeyStateEncoderSpec.fromJson(keySchema, jsonMap) + assert(encoderSpec == deserializedEncoderSpec) + } + + test("test serialization and deserialization of RangeKeyScanStateEncoderSpec") { + implicit val formats: DefaultFormats.type = DefaultFormats + val encoderSpec = RangeKeyScanStateEncoderSpec(keySchema, Seq(1)) + val jsonMap = JsonMethods.parse(encoderSpec.json).extract[Map[String, Any]] + val deserializedEncoderSpec = KeyStateEncoderSpec.fromJson(keySchema, jsonMap) + assert(encoderSpec == deserializedEncoderSpec) + } + /** Return a new provider with a random id */ def newStoreProvider(): ProviderClass diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 5e408dc999f82..7954ee510ca1d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -20,16 +20,19 @@ package org.apache.spark.sql.streaming import java.io.File import java.util.UUID +import 
org.apache.hadoop.fs.Path + import org.apache.spark.SparkRuntimeException import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Dataset, Encoders} +import org.apache.spark.sql.{Dataset, Encoders, Row} import org.apache.spark.sql.catalyst.util.stringToFile import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} -import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, RocksDBStateStoreProvider, StatefulProcessorCannotPerformOperationWithInvalidHandleState, StateSchemaV3File, StateStoreMultipleColumnFamiliesNotSupportedException} +import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, KEY_ROW_SCHEMA} +import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, ColumnFamilySchema, ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, OperatorInfoV1, OperatorStateMetadataV2, POJOTestClass, PrefixKeyScanStateEncoderSpec, RocksDBStateStoreProvider, StatefulProcessorCannotPerformOperationWithInvalidHandleState, StateSchemaV3File, StateStoreMetadataV2, StateStoreMultipleColumnFamiliesNotSupportedException, StateStoreValueSchemaNotCompatible, TestClass} import org.apache.spark.sql.functions.timestamp_seconds import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock +import org.apache.spark.sql.types._ object TransformWithStateSuiteUtils { val NUM_SHUFFLE_PARTITIONS = 5 @@ -61,6 +64,32 @@ class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (S } } +class RunningCountStatefulProcessorInt extends StatefulProcessor[String, String, (String, String)] + with Logging { + @transient protected var _countState: ValueState[Int] = _ + + override def init( + outputMode: OutputMode, + timeMode: TimeMode): Unit = { + _countState = getHandle.getValueState[Int]("countState", Encoders.scalaInt) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[String], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { + val count = _countState.getOption().getOrElse(0) + 1 + if (count == 3) { + _countState.clear() + Iterator.empty + } else { + _countState.update(count) + Iterator((key, count.toString)) + } + } +} + // Class to verify stateful processor usage with adding processing time timers class RunningCountStatefulProcessorWithProcTimeTimer extends RunningCountStatefulProcessor { private def handleProcessingTimeBasedTimers( @@ -306,6 +335,21 @@ class RunningCountStatefulProcessorWithError extends RunningCountStatefulProcess } } +class StatefulProcessorWithCompositeTypes extends RunningCountStatefulProcessor { + @transient private var _listState: ListState[TestClass] = _ + @transient private var _mapState: MapState[String, POJOTestClass] = _ + + override def init( + outputMode: OutputMode, + timeMode: TimeMode): Unit = { + _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong) + _listState = getHandle.getListState[TestClass]( + "listState", Encoders.product[TestClass]) + _mapState = getHandle.getMapState[String, POJOTestClass]( + "mapState", Encoders.STRING, Encoders.bean(classOf[POJOTestClass])) + } +} + /** * Class that adds tests for transformWithState stateful streaming operator */ @@ -786,6 +830,187 @@ class TransformWithStateSuite extends 
StateStoreMetricsTest } } + private def fetchOperatorStateMetadataLog( + checkpointDir: String, + operatorId: Int): OperatorStateMetadataLog = { + val hadoopConf = spark.sessionState.newHadoopConf() + val stateChkptPath = new Path(checkpointDir, s"state/$operatorId") + val operatorStateMetadataPath = OperatorStateMetadataV2.metadataFilePath(stateChkptPath) + new OperatorStateMetadataLog(hadoopConf, operatorStateMetadataPath.toString) + } + + private def fetchColumnFamilySchemas( + checkpointDir: String, + operatorId: Int): List[ColumnFamilySchema] = { + val operatorStateMetadataLog = fetchOperatorStateMetadataLog(checkpointDir, operatorId) + val stateSchemaFilePath = operatorStateMetadataLog. + getLatest().get._2. + asInstanceOf[OperatorStateMetadataV2]. + stateStoreInfo.head.stateSchemaFilePath + fetchStateSchemaV3File(checkpointDir, operatorId).getWithPath(new Path(stateSchemaFilePath)) + } + + private def fetchStateSchemaV3File( + checkpointDir: String, + operatorId: Int): StateSchemaV3File = { + val hadoopConf = spark.sessionState.newHadoopConf() + val stateChkptPath = new Path(checkpointDir, s"state/$operatorId") + val stateSchemaPath = new Path(new Path(stateChkptPath, "_metadata"), "schema") + new StateSchemaV3File(hadoopConf, stateSchemaPath.toString) + } + + test("transformWithState - verify StateSchemaV3 file is written correctly") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { chkptDir => + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + StopStream + ) + + val columnFamilySchemas = fetchColumnFamilySchemas(chkptDir.getCanonicalPath, 0) + assert(columnFamilySchemas.length == 1) + val expected = ColumnFamilySchemaV1( + "countState", + new StructType().add("key", + new StructType().add("value", StringType)), + new StructType().add("value", + new StructType().add("value", LongType, false)), + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), + None + ) + val actual = columnFamilySchemas.head.asInstanceOf[ColumnFamilySchemaV1] + assert(expected == actual) + } + } + } + + test("transformWithState - verify StateSchemaV3 file is written correctly," + + " multiple column families") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { chkptDir => + val inputData = MemoryStream[(String, String)] + val result = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new RunningCountMostRecentStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "1", "")), + StopStream + ) + + val columnFamilySchemas = fetchColumnFamilySchemas(chkptDir.getCanonicalPath, 0) + assert(columnFamilySchemas.length == 2) + + val expected = List( + ColumnFamilySchemaV1( + "countState", + new StructType().add("key", + new StructType().add("value", StringType)), + new StructType().add("value", + new 
+              new StructType().add("value", LongType, false)),
+            NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA),
+            None
+          ),
+          ColumnFamilySchemaV1(
+            "mostRecent",
+            new StructType().add("key",
+              new StructType().add("value", StringType)),
+            new StructType().add("value",
+              new StructType().add("value", StringType, true)),
+            NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA),
+            None
+          )
+        )
+        val actual = columnFamilySchemas.map(_.asInstanceOf[ColumnFamilySchemaV1])
+        assert(expected == actual)
+      }
+    }
+  }
+
+  test("transformWithState - verify that OperatorStateMetadataV2" +
+    " file is being written correctly") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key ->
+        TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
+      withTempDir { checkpointDir =>
+        val inputData = MemoryStream[String]
+        val result = inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new RunningCountStatefulProcessor(),
+            TimeMode.None(),
+            OutputMode.Update())
+
+        testStream(result, OutputMode.Update())(
+          StartStream(checkpointLocation = checkpointDir.getCanonicalPath),
+          AddData(inputData, "a"),
+          CheckNewAnswer(("a", "1")),
+          StopStream,
+          StartStream(checkpointLocation = checkpointDir.getCanonicalPath),
+          AddData(inputData, "a"),
+          CheckNewAnswer(("a", "2")),
+          StopStream
+        )
+
+        val df = spark.read.format("state-metadata").load(checkpointDir.toString)
+        checkAnswer(df, Seq(
+          Row(0, "transformWithStateExec", "default", 5, 0L, 0L),
+          Row(0, "transformWithStateExec", "default", 5, 1L, 1L)
+        ))
+        checkAnswer(df.select(df.metadataColumn("_operatorProperties")),
+          Seq(
+            Row("""{"timeMode":"NoTime","outputMode":"Update"}"""),
+            Row("""{"timeMode":"NoTime","outputMode":"Update"}""")
+          )
+        )
+      }
+    }
+  }
+
+  test("transformWithState - verify OperatorStateMetadataV2 serialization and deserialization" +
+    " works") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key ->
+        TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
+      withTempDir { checkpointDir =>
+        val metadata = OperatorStateMetadataV2(
+          OperatorInfoV1(0, "transformWithStateExec"),
+          Array(StateStoreMetadataV2("default", 0, 0, "path")),
+          """{"timeMode":"NoTime","outputMode":"Update"}""")
+        val metadataLog = new OperatorStateMetadataLog(
+          spark.sessionState.newHadoopConf(),
+          checkpointDir.getCanonicalPath)
+        metadataLog.add(0, metadata)
+        assert(metadataLog.get(0).isDefined)
+        // assert that each of the fields are the same
+        val metadataV2 = metadataLog.get(0).get.asInstanceOf[OperatorStateMetadataV2]
+        assert(metadataV2.operatorInfo == metadata.operatorInfo)
+        assert(metadataV2.stateStoreInfo.sameElements(metadata.stateStoreInfo))
+        assert(metadataV2.operatorPropertiesJson == metadata.operatorPropertiesJson)
+      }
+    }
+  }
+
   test("transformWithState - verify StateSchemaV3 serialization and deserialization" +
     " works with one batch") {
     withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
@@ -795,8 +1020,10 @@ class TransformWithStateSuite extends StateStoreMetricsTest
       withTempDir { checkpointDir =>
         val schema = List(ColumnFamilySchemaV1(
           "countState",
-          KEY_ROW_SCHEMA,
-          VALUE_ROW_SCHEMA,
+          new StructType().add("key",
+            new StructType().add("value", StringType)),
+          new StructType().add("value",
+            new StructType().add("value", LongType, false)),
           NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA),
           None
         ))
@@ -820,8 +1047,10 @@ class TransformWithStateSuite extends StateStoreMetricsTest
 
         val schema0 = List(ColumnFamilySchemaV1(
           "countState",
-          KEY_ROW_SCHEMA,
-          VALUE_ROW_SCHEMA,
+          new StructType().add("key",
+            new StructType().add("value", StringType)),
+          new StructType().add("value",
+            new StructType().add("value", LongType, false)),
           NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA),
           None
         ))
@@ -829,15 +1058,19 @@ class TransformWithStateSuite extends StateStoreMetricsTest
         val schema1 = List(
           ColumnFamilySchemaV1(
             "countState",
-            KEY_ROW_SCHEMA,
-            VALUE_ROW_SCHEMA,
+            new StructType().add("key",
+              new StructType().add("value", StringType)),
+            new StructType().add("value",
+              new StructType().add("value", LongType, false)),
             NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA),
             None
           ),
           ColumnFamilySchemaV1(
             "mostRecent",
-            KEY_ROW_SCHEMA,
-            VALUE_ROW_SCHEMA,
+            new StructType().add("key",
+              new StructType().add("value", StringType)),
+            new StructType().add("value",
+              new StructType().add("value", StringType, false)),
             NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA),
             None
           )
@@ -858,6 +1091,43 @@ class TransformWithStateSuite extends StateStoreMetricsTest
       }
     }
   }
+
+  test("test that invalid schema evolution fails query for column family") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key ->
+        TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
+      withTempDir { checkpointDir =>
+        val inputData = MemoryStream[String]
+        val result1 = inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new RunningCountStatefulProcessor(),
+            TimeMode.None(),
+            OutputMode.Update())
+
+        testStream(result1, OutputMode.Update())(
+          StartStream(checkpointLocation = checkpointDir.getCanonicalPath),
+          AddData(inputData, "a"),
+          CheckNewAnswer(("a", "1")),
+          StopStream
+        )
+        val result2 = inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new RunningCountStatefulProcessorInt(),
+            TimeMode.None(),
+            OutputMode.Update())
+        testStream(result2, OutputMode.Update())(
+          StartStream(checkpointLocation = checkpointDir.getCanonicalPath),
+          AddData(inputData, "a"),
+          ExpectFailure[StateStoreValueSchemaNotCompatible] {
+            (t: Throwable) => {
+              assert(t.getMessage.contains("Please check number and type of fields."))
+            }
+          }
+        )
+      }
+    }
+  }
 }
 
 class TransformWithStateValidationSuite extends StateStoreMetricsTest {
@@ -899,3 +1169,88 @@ class TransformWithStateValidationSuite extends StateStoreMetricsTest {
     )
   }
 }
+
+class TransformWithStateSchemaSuite extends StateStoreMetricsTest {
+
+  import testImplicits._
+
+  test("transformWithState - verify StateSchemaV3 writes correct SQL schema of key/value") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key ->
+        TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
+      withTempDir { checkpointDir =>
+        val metadataPathPostfix = "state/0/default/_metadata"
+        val stateSchemaPath = new Path(checkpointDir.toString,
+          s"$metadataPathPostfix/schema/0")
+        val hadoopConf = spark.sessionState.newHadoopConf()
+        val fm = CheckpointFileManager.create(stateSchemaPath, hadoopConf)
+
+        val schema0 = ColumnFamilySchemaV1(
+          "countState",
+          new StructType().add("key",
+            new StructType().add("value", StringType)),
+          new StructType().add("value",
+            new StructType().add("value", LongType, false)),
+          NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA),
+          None
+        )
+        val schema1 = ColumnFamilySchemaV1(
+          "listState",
+          new StructType().add("key",
+            new StructType().add("value", StringType)),
+          new StructType().add("value",
+            new StructType()
+              .add("id", LongType, false)
+              .add("name", StringType)),
+          NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA),
+          None
+        )
+        val schema2 = ColumnFamilySchemaV1(
+          "mapState",
+          new StructType()
+            .add("key", new StructType().add("value", StringType))
+            .add("userKey", new StructType().add("value", StringType)),
+          new StructType().add("value",
+            new StructType()
+              .add("id", IntegerType, false)
+              .add("name", StringType)),
+          PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1),
+          Option(new StructType().add("value", StringType))
+        )
+
+        val inputData = MemoryStream[String]
+        val result = inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new StatefulProcessorWithCompositeTypes(),
+            TimeMode.None(),
+            OutputMode.Update())
+
+        testStream(result, OutputMode.Update())(
+          StartStream(checkpointLocation = checkpointDir.getCanonicalPath),
+          AddData(inputData, "a", "b"),
+          CheckNewAnswer(("a", "1"), ("b", "1")),
+          Execute { q =>
+            val schemaFilePath = fm.list(stateSchemaPath).toSeq.head.getPath
+            val ssv3 = new StateSchemaV3File(hadoopConf, new Path(checkpointDir.toString,
+              metadataPathPostfix).toString)
+            val colFamilySeq = ssv3.deserialize(fm.open(schemaFilePath))
+
+            assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS ==
+              q.lastProgress.stateOperators.head.customMetrics.get("numValueStateVars").toInt)
+            assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS ==
+              q.lastProgress.stateOperators.head.customMetrics.get("numListStateVars").toInt)
+            assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS ==
+              q.lastProgress.stateOperators.head.customMetrics.get("numMapStateVars").toInt)
+
+            assert(colFamilySeq.length == 3)
+            assert(colFamilySeq.map(_.toString).toSet == Set(
+              schema0, schema1, schema2
+            ).map(_.toString))
+          },
+          StopStream
+        )
+      }
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
index 54004b419f759..d4d730dac9996 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
@@ -19,12 +19,16 @@ package org.apache.spark.sql.streaming
 
 import java.time.Duration
 
+import org.apache.hadoop.fs.Path
+
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Encoders
-import org.apache.spark.sql.execution.streaming.{MemoryStream, ValueStateImpl, ValueStateImplWithTTL}
-import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
+import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, ListStateImplWithTTL, MapStateImplWithTTL, MemoryStream, ValueStateImpl, ValueStateImplWithTTL}
+import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, KEY_ROW_SCHEMA}
+import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, RocksDBStateStoreProvider, StateSchemaV3File}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.sql.types._
 
 object TTLInputProcessFunction {
   def processRow(
@@ -111,15 +115,15 @@ class ValueStateTTLProcessor(ttlConfig: TTLConfig)
   }
 }
 
-case class MultipleValueStatesTTLProcessor(
+class MultipleValueStatesTTLProcessor(
     ttlKey: String,
     noTtlKey: String,
     ttlConfig: TTLConfig)
   extends StatefulProcessor[String, InputEvent, OutputEvent]
   with Logging {
-  @transient private var _valueStateWithTTL: ValueStateImplWithTTL[Int] = _
-  @transient private var _valueStateWithoutTTL: ValueStateImpl[Int] = _
+  @transient var _valueStateWithTTL: ValueStateImplWithTTL[Int] = _
+  @transient var _valueStateWithoutTTL: ValueStateImpl[Int] = _
 
   override def init(
       outputMode: OutputMode,
@@ -160,6 +164,28 @@ case class MultipleValueStatesTTLProcessor(
   }
 }
 
+class TTLProcessorWithCompositeTypes(
+    ttlKey: String,
+    noTtlKey: String,
+    ttlConfig: TTLConfig)
+  extends MultipleValueStatesTTLProcessor(
+    ttlKey: String, noTtlKey: String, ttlConfig: TTLConfig) {
+  @transient private var _listStateWithTTL: ListStateImplWithTTL[Int] = _
+  @transient private var _mapStateWithTTL: MapStateImplWithTTL[Int, String] = _
+
+  override def init(
+      outputMode: OutputMode,
+      timeMode: TimeMode): Unit = {
+    super.init(outputMode, timeMode)
+    _listStateWithTTL = getHandle
+      .getListState("listState", Encoders.scalaInt, ttlConfig)
+      .asInstanceOf[ListStateImplWithTTL[Int]]
+    _mapStateWithTTL = getHandle
+      .getMapState("mapState", Encoders.scalaInt, Encoders.STRING, ttlConfig)
+      .asInstanceOf[MapStateImplWithTTL[Int, String]]
+  }
+}
+
 class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
 
   import testImplicits._
@@ -181,7 +207,7 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
     val result = inputStream.toDS()
       .groupByKey(x => x.key)
       .transformWithState(
-        MultipleValueStatesTTLProcessor(ttlKey, noTtlKey, ttlConfig),
+        new MultipleValueStatesTTLProcessor(ttlKey, noTtlKey, ttlConfig),
         TimeMode.ProcessingTime(),
         OutputMode.Append())
 
@@ -225,4 +251,107 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
       )
     }
   }
+
+  test("verify StateSchemaV3 writes correct SQL schema of key/value and with TTL") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key ->
+        TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
+      withTempDir { checkpointDir =>
+        val metadataPathPostfix = "state/0/default/_metadata"
+        val stateSchemaPath = new Path(checkpointDir.toString,
+          s"$metadataPathPostfix/schema/0")
+        val hadoopConf = spark.sessionState.newHadoopConf()
+        val fm = CheckpointFileManager.create(stateSchemaPath, hadoopConf)
+
+        val schema0 = ColumnFamilySchemaV1(
+          "valueState",
+          new StructType().add("key",
+            new StructType().add("value", StringType)),
+          new StructType().add("value",
+            new StructType().add("value", LongType, false)
+              .add("ttlExpirationMs", LongType)),
+          NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA),
+          None
+        )
+        val schema1 = ColumnFamilySchemaV1(
+          "listState",
+          new StructType().add("key",
+            new StructType().add("value", StringType)),
+          new StructType().add("value",
+            new StructType()
+              .add("value", IntegerType, false)
+              .add("ttlExpirationMs", LongType)),
+          NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA),
+          None
+        )
+        val schema2 = ColumnFamilySchemaV1(
+          "mapState",
+          new StructType()
+            .add("key", new StructType().add("value", StringType))
+            .add("userKey", new StructType().add("value", StringType)),
+          new StructType().add("value",
+            new StructType()
+              .add("value", IntegerType, false)
+              .add("ttlExpirationMs", LongType)),
+          PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1),
+          Option(new StructType().add("value", StringType))
+        )
+
+        val ttlKey = "k1"
+        val noTtlKey = "k2"
+        val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
+        val inputStream = MemoryStream[InputEvent]
+        val result = inputStream.toDS()
+          .groupByKey(x => x.key)
+          .transformWithState(
+            new TTLProcessorWithCompositeTypes(ttlKey, noTtlKey, ttlConfig),
+            TimeMode.ProcessingTime(),
+            OutputMode.Append())
+
+        val clock = new StreamManualClock
+        testStream(result)(
+          StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock,
+            checkpointLocation = checkpointDir.getCanonicalPath),
+          AddData(inputStream, InputEvent(ttlKey, "put", 1)),
+          AddData(inputStream, InputEvent(noTtlKey, "put", 2)),
+          // advance clock to trigger processing
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(),
+          Execute { q =>
+            val schemaFilePath = fm.list(stateSchemaPath).toSeq.head.getPath
+            val ssv3 = new StateSchemaV3File(hadoopConf, new Path(checkpointDir.toString,
+              metadataPathPostfix).toString)
+            val colFamilySeq = ssv3.deserialize(fm.open(schemaFilePath))
+
+            assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS ==
+              q.lastProgress.stateOperators.head.customMetrics.get("numValueStateVars").toInt)
+            assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS ==
+              q.lastProgress.stateOperators.head.customMetrics
+                .get("numValueStateWithTTLVars").toInt)
+            assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS ==
+              q.lastProgress.stateOperators.head.customMetrics
+                .get("numListStateWithTTLVars").toInt)
+            assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS ==
+              q.lastProgress.stateOperators.head.customMetrics
+                .get("numMapStateWithTTLVars").toInt)
+
+            // TODO when there are two state var with the same name,
+            // only one schema file is preserved
+            assert(colFamilySeq.length == 3)
+            /*
+            assert(colFamilySeq.map(_.toString).toSet == Set(
+              schema0, schema1, schema2
+            ).map(_.toString)) */
+
+            assert(colFamilySeq(1).toString == schema1.toString)
+            assert(colFamilySeq(2).toString == schema2.toString)
+            // The remaining schema file is the one without ttl
+            // assert(colFamilySeq.head.toString == schema0.toString)
+          },
+          StopStream
+        )
+      }
+    }
+  }
 }