diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala index 9c8b2d0375588..cfe98c13f8c07 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala @@ -24,6 +24,7 @@ import org.apache.avro.generic.GenericDatumReader import org.apache.avro.io.{BinaryDecoder, DecoderFactory} import org.apache.spark.SparkException +import org.apache.spark.sql.avro.SchemaConverters import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, SpecificInternalRow, UnaryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.{FailFastMode, ParseMode, PermissiveMode} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/StructFilters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/StructFilters.scala index 1b2013d87eedf..34a23cabf6f7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/StructFilters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/StructFilters.scala @@ -36,7 +36,8 @@ import org.apache.spark.util.ArrayImplicits._ * the fields of the provided schema. * @param schema The required schema of records from datasource files. */ -abstract class StructFilters(pushedFilters: Seq[sources.Filter], schema: StructType) { +abstract class StructFilters( + pushedFilters: Seq[sources.Filter], schema: StructType) extends Serializable { protected val filters = StructFilters.pushedFilters(pushedFilters.toArray, schema) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 50f7143145a94..c089ffbc5857c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2204,6 +2204,16 @@ object SQLConf { .intConf .createWithDefault(3) + val STREAMING_STATE_STORE_ENCODING_FORMAT = + buildConf("spark.sql.streaming.stateStore.encodingFormat") + .doc("The encoding format used for stateful operators to store information" + + "in the state store") + .version("4.0.0") + .stringConf + .checkValue(v => Set("UnsafeRow", "Avro").contains(v), + "Valid values are 'UnsafeRow' and 'Avro'") + .createWithDefault("UnsafeRow") + // The feature is still in development, so it is still internal. val STATE_STORE_CHECKPOINT_FORMAT_VERSION = buildConf("spark.sql.streaming.stateStore.checkpointFormatVersion") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala index ac20614553ca2..7addc7608260e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -52,7 +52,7 @@ private[sql] class AvroDeserializer( filters: StructFilters, useStableIdForUnionType: Boolean, stableIdPrefixForUnionType: String, - recursiveFieldMaxDepth: Int) { + recursiveFieldMaxDepth: Int) extends Serializable { def this( rootAvroType: Schema, @@ -463,7 +463,7 @@ private[sql] class AvroDeserializer( * A base interface for updating values inside catalyst data structure like `InternalRow` and * `ArrayData`. 
*/ - sealed trait CatalystDataUpdater { + sealed trait CatalystDataUpdater extends Serializable { def set(ordinal: Int, value: Any): Unit def setNullAt(ordinal: Int): Unit = set(ordinal, null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala index 814a28e24f522..3aefb47d20825 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala @@ -47,7 +47,7 @@ private[sql] class AvroSerializer( rootAvroType: Schema, nullable: Boolean, positionalFieldMatch: Boolean, - datetimeRebaseMode: LegacyBehaviorPolicy.Value) extends Logging { + datetimeRebaseMode: LegacyBehaviorPolicy.Value) extends Logging with Serializable { def this(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean) = { this(rootCatalystType, rootAvroType, nullable, positionalFieldMatch = false, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 4b8bc72b2ed7f..634222e785a44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -259,6 +259,19 @@ class IncrementalExecution( } } + object StateStoreColumnFamilySchemasRule extends SparkPlanPartialRule { + override val rule: PartialFunction[SparkPlan, SparkPlan] = { + case statefulOp: StatefulOperator => + statefulOp match { + case op: TransformWithStateExec => + op.copy( + columnFamilySchemas = op.getColFamilySchemas() + ) + case _ => statefulOp + } + } + } + object StateOpIdRule extends SparkPlanPartialRule { override val rule: PartialFunction[SparkPlan, SparkPlan] = { case StateStoreSaveExec(keys, None, None, None, None, stateFormatVersion, @@ -552,9 +565,9 @@ class IncrementalExecution( // The rule below doesn't change the plan but can cause the side effect that // metadata/schema is written in the checkpoint directory of stateful operator. 
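// Illustration only, not part of this patch: a minimal sketch of how a query opts into
// the encoding format governed by the spark.sql.streaming.stateStore.encodingFormat conf
// added to SQLConf above (valid values are "UnsafeRow", the default, and "Avro"). The
// session and app name below are placeholders. The conf is consulted on the driver when
// the operator builds its driver-side processor handle, which is also when the
// StateStoreColumnFamilySchemasRule above copies the collected schemas into
// TransformWithStateExec.
object StateStoreEncodingFormatSketch {
  import org.apache.spark.sql.SparkSession

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("avro-state-encoding").getOrCreate()
    // Set before starting the transformWithState query so the driver-side handle
    // builds Avro serializers/deserializers for each state variable.
    spark.conf.set("spark.sql.streaming.stateStore.encodingFormat", "Avro")
  }
}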
planWithStateOpId transform StateSchemaAndOperatorMetadataRule.rule - - simulateWatermarkPropagation(planWithStateOpId) - planWithStateOpId transform WatermarkPropagationRule.rule + val planWithStateSchemas = planWithStateOpId transform StateStoreColumnFamilySchemasRule.rule + simulateWatermarkPropagation(planWithStateSchemas) + planWithStateSchemas transform WatermarkPropagationRule.rule } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala index 32683aebd8c18..e0a201834e24d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} +import org.apache.spark.sql.execution.streaming.state.{AvroEncoder, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.ListState import org.apache.spark.sql.types.StructType @@ -32,6 +32,8 @@ import org.apache.spark.sql.types.StructType * @param keyExprEnc - Spark SQL encoder for key * @param valEncoder - Spark SQL encoder for value * @param metrics - metrics to be updated as part of stateful processing + * @param avroEnc - optional Avro serializer and deserializer for this state variable that + * is used by the StateStore to encode state in Avro format * @tparam S - data type of object that will be stored in the list */ class ListStateImpl[S]( @@ -39,7 +41,8 @@ class ListStateImpl[S]( stateName: String, keyExprEnc: ExpressionEncoder[Any], valEncoder: ExpressionEncoder[Any], - metrics: Map[String, SQLMetric] = Map.empty) + metrics: Map[String, SQLMetric] = Map.empty, + avroEnc: Option[AvroEncoder] = None) extends ListStateMetricsImpl with ListState[S] with Logging { @@ -50,8 +53,13 @@ class ListStateImpl[S]( private val stateTypesEncoder = StateTypesEncoder(keyExprEnc, valEncoder, stateName) - store.createColFamilyIfAbsent(stateName, keyExprEnc.schema, valEncoder.schema, - NoPrefixKeyStateEncoderSpec(keyExprEnc.schema), useMultipleValuesPerKey = true) + store.createColFamilyIfAbsent( + stateName, + keyExprEnc.schema, + valEncoder.schema, + NoPrefixKeyStateEncoderSpec(keyExprEnc.schema), + useMultipleValuesPerKey = true, + avroEncoderSpec = avroEnc) /** Whether state exists or not. 
*/ override def exists(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala index 4c8dd6a193c25..313e974fd5a7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala @@ -20,7 +20,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._ -import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} +import org.apache.spark.sql.execution.streaming.state.{AvroEncoder, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.{ListState, TTLConfig} import org.apache.spark.sql.types.StructType import org.apache.spark.util.NextIterator @@ -36,6 +36,10 @@ import org.apache.spark.util.NextIterator * @param ttlConfig - TTL configuration for values stored in this state * @param batchTimestampMs - current batch processing timestamp. * @param metrics - metrics to be updated as part of stateful processing + * @param avroEnc - optional Avro serializer and deserializer for this state variable that + * is used by the StateStore to encode state in Avro format + * @param secondaryIndexAvroEnc - optional Avro serializer and deserializer for TTL state that + * is used by the StateStore to encode state in Avro format * @tparam S - data type of object that will be stored */ class ListStateImplWithTTL[S]( @@ -45,8 +49,11 @@ class ListStateImplWithTTL[S]( valEncoder: ExpressionEncoder[Any], ttlConfig: TTLConfig, batchTimestampMs: Long, - metrics: Map[String, SQLMetric] = Map.empty) - extends SingleKeyTTLStateImpl(stateName, store, keyExprEnc, batchTimestampMs) + metrics: Map[String, SQLMetric] = Map.empty, + avroEnc: Option[AvroEncoder] = None, + secondaryIndexAvroEnc: Option[AvroEncoder] = None) + extends SingleKeyTTLStateImpl( + stateName, store, keyExprEnc, batchTimestampMs, secondaryIndexAvroEnc) with ListStateMetricsImpl with ListState[S] { @@ -65,7 +72,8 @@ class ListStateImplWithTTL[S]( private def initialize(): Unit = { store.createColFamilyIfAbsent(stateName, keyExprEnc.schema, getValueSchemaWithTTL(valEncoder.schema, true), - NoPrefixKeyStateEncoderSpec(keyExprEnc.schema), useMultipleValuesPerKey = true) + NoPrefixKeyStateEncoderSpec(keyExprEnc.schema), useMultipleValuesPerKey = true, + avroEncoderSpec = avroEnc) } /** Whether state exists or not. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala index 4e608a5d5dbbe..b57eaec8d1e3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala @@ -20,7 +20,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._ -import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors, UnsafeRowPair} +import org.apache.spark.sql.execution.streaming.state.{AvroEncoder, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors, UnsafeRowPair} import org.apache.spark.sql.streaming.MapState import org.apache.spark.sql.types.StructType @@ -32,6 +32,8 @@ import org.apache.spark.sql.types.StructType * @param keyExprEnc - Spark SQL encoder for key * @param valEncoder - Spark SQL encoder for value * @param metrics - metrics to be updated as part of stateful processing + * @param avroEnc - optional Avro serializer and deserializer for this state variable that + * is used by the StateStore to encode state in Avro format * @tparam K - type of key for map state variable * @tparam V - type of value for map state variable */ @@ -41,7 +43,8 @@ class MapStateImpl[K, V]( keyExprEnc: ExpressionEncoder[Any], userKeyEnc: ExpressionEncoder[Any], valEncoder: ExpressionEncoder[Any], - metrics: Map[String, SQLMetric] = Map.empty) extends MapState[K, V] with Logging { + metrics: Map[String, SQLMetric] = Map.empty, + avroEnc: Option[AvroEncoder] = None) extends MapState[K, V] with Logging { // Pack grouping key and user key together as a prefixed composite key private val schemaForCompositeKeyRow: StructType = { @@ -52,7 +55,7 @@ class MapStateImpl[K, V]( keyExprEnc, userKeyEnc, valEncoder, stateName) store.createColFamilyIfAbsent(stateName, schemaForCompositeKeyRow, schemaForValueRow, - PrefixKeyScanStateEncoderSpec(schemaForCompositeKeyRow, 1)) + PrefixKeyScanStateEncoderSpec(schemaForCompositeKeyRow, 1), avroEncoderSpec = avroEnc) /** Whether state exists or not. 
*/ override def exists(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala index 19704b6d1bd59..f267304b1fe4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._ -import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors} +import org.apache.spark.sql.execution.streaming.state.{AvroEncoder, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.{MapState, TTLConfig} import org.apache.spark.util.NextIterator @@ -36,21 +36,27 @@ import org.apache.spark.util.NextIterator * @param ttlConfig - the ttl configuration (time to live duration etc.) * @param batchTimestampMs - current batch processing timestamp. * @param metrics - metrics to be updated as part of stateful processing + * @param avroEnc - optional Avro serializer and deserializer for this state variable that + * is used by the StateStore to encode state in Avro format + * @param secondaryIndexAvroEnc - optional Avro serializer and deserializer for TTL state that + * is used by the StateStore to encode state in Avro format * @tparam K - type of key for map state variable * @tparam V - type of value for map state variable * @return - instance of MapState of type [K,V] that can be used to store state persistently */ class MapStateImplWithTTL[K, V]( - store: StateStore, - stateName: String, - keyExprEnc: ExpressionEncoder[Any], - userKeyEnc: ExpressionEncoder[Any], - valEncoder: ExpressionEncoder[Any], - ttlConfig: TTLConfig, - batchTimestampMs: Long, - metrics: Map[String, SQLMetric] = Map.empty) + store: StateStore, + stateName: String, + keyExprEnc: ExpressionEncoder[Any], + userKeyEnc: ExpressionEncoder[Any], + valEncoder: ExpressionEncoder[Any], + ttlConfig: TTLConfig, + batchTimestampMs: Long, + metrics: Map[String, SQLMetric] = Map.empty, + avroEnc: Option[AvroEncoder] = None, + secondaryIndexAvroEnc: Option[AvroEncoder] = None) extends CompositeKeyTTLStateImpl[K](stateName, store, - keyExprEnc, userKeyEnc, batchTimestampMs) + keyExprEnc, userKeyEnc, batchTimestampMs, secondaryIndexAvroEnc) with MapState[K, V] with Logging { private val stateTypesEncoder = new CompositeKeyStateEncoder( @@ -66,7 +72,8 @@ class MapStateImplWithTTL[K, V]( getCompositeKeySchema(keyExprEnc.schema, userKeyEnc.schema) store.createColFamilyIfAbsent(stateName, schemaForCompositeKeyRow, getValueSchemaWithTTL(valEncoder.schema, true), - PrefixKeyScanStateEncoderSpec(schemaForCompositeKeyRow, 1)) + PrefixKeyScanStateEncoderSpec(schemaForCompositeKeyRow, 1), + avroEncoderSpec = avroEnc) } /** Whether state exists or not. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala index 7da8408f98b0f..5662e91c9926a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala @@ -16,24 +16,118 @@ */ package org.apache.spark.sql.execution.streaming +import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder +import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions, AvroSerializer, SchemaConverters} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._ -import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, StateStoreColFamilySchema} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.execution.streaming.state.{AvroEncoder, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, RangeKeyScanStateEncoderSpec, StateStoreColFamilySchema} +import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, DataType, DoubleType, FloatType, IntegerType, LongType, NullType, ShortType, StructField, StructType} -object StateStoreColumnFamilySchemaUtils { +object StateStoreColumnFamilySchemaUtils extends Serializable { + + def apply(initializeAvroSerde: Boolean): StateStoreColumnFamilySchemaUtils = + new StateStoreColumnFamilySchemaUtils(initializeAvroSerde) + + /** + * Avro uses zig-zag encoding for some fixed-length types, like Longs and Ints. For range scans + * we want to use big-endian encoding, so we need to convert the source schema to replace these + * types with BinaryType. + * + * @param schema The schema to convert + * @param ordinals If non-empty, only convert fields at these ordinals. + * If empty, convert all fields. + */ + def convertForRangeScan(schema: StructType, ordinals: Seq[Int] = Seq.empty): StructType = { + val ordinalSet = ordinals.toSet + + StructType(schema.fields.zipWithIndex.flatMap { case (field, idx) => + if ((ordinals.isEmpty || ordinalSet.contains(idx)) && isFixedSize(field.dataType)) { + // For each numeric field, create two fields: + // 1. A boolean for sign (positive = true, negative = false) + // 2. The original numeric value in big-endian format + Seq( + StructField(s"${field.name}_marker", ByteType, nullable = false), + field.copy(name = s"${field.name}_value", BinaryType) + ) + } else { + Seq(field) + } + }) + } + + private def isFixedSize(dataType: DataType): Boolean = dataType match { + case _: ByteType | _: BooleanType | _: ShortType | _: IntegerType | _: LongType | + _: FloatType | _: DoubleType => true + case _ => false + } + + def getTtlColFamilyName(stateName: String): String = { + "$ttl_" + stateName + } +} + +/** + * + * @param initializeAvroSerde Whether or not to create the Avro serializers and deserializers + * for this state type. 
This class is used to create the + * StateStoreColumnFamilySchema for each state variable from the driver + */ +class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) + extends Logging with Serializable { + private def getAvroSerializer(schema: StructType): AvroSerializer = { + val avroType = SchemaConverters.toAvroType(schema) + new AvroSerializer(schema, avroType, nullable = false) + } + + private def getAvroDeserializer(schema: StructType): AvroDeserializer = { + val avroType = SchemaConverters.toAvroType(schema) + val avroOptions = AvroOptions(Map.empty) + new AvroDeserializer(avroType, schema, + avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType, + avroOptions.stableIdPrefixForUnionType, avroOptions.recursiveFieldMaxDepth) + } + + /** + * If initializeAvroSerde is true, this method will create an Avro Serializer and Deserializer + * for a particular key and value schema. + */ + private[sql] def getAvroSerde( + keySchema: StructType, + valSchema: StructType, + suffixKeySchema: Option[StructType] = None + ): Option[AvroEncoder] = { + if (initializeAvroSerde) { + + val (suffixKeySer, suffixKeyDe) = if (suffixKeySchema.isDefined) { + (Some(getAvroSerializer(suffixKeySchema.get)), + Some(getAvroDeserializer(suffixKeySchema.get))) + } else { + (None, None) + } + Some(AvroEncoder( + getAvroSerializer(keySchema), + getAvroDeserializer(keySchema), + getAvroSerializer(valSchema), + getAvroDeserializer(valSchema), + suffixKeySer, suffixKeyDe)) + } else { + None + } + } def getValueStateSchema[T]( stateName: String, keyEncoder: ExpressionEncoder[Any], valEncoder: Encoder[T], hasTtl: Boolean): StateStoreColFamilySchema = { + val valSchema = getValueSchemaWithTTL(valEncoder.schema, hasTtl) StateStoreColFamilySchema( stateName, keyEncoder.schema, - getValueSchemaWithTTL(valEncoder.schema, hasTtl), - Some(NoPrefixKeyStateEncoderSpec(keyEncoder.schema))) + valSchema, + Some(NoPrefixKeyStateEncoderSpec(keyEncoder.schema)), + avroEnc = getAvroSerde(keyEncoder.schema, valSchema)) } def getListStateSchema[T]( @@ -41,11 +135,13 @@ object StateStoreColumnFamilySchemaUtils { keyEncoder: ExpressionEncoder[Any], valEncoder: Encoder[T], hasTtl: Boolean): StateStoreColFamilySchema = { + val valSchema = getValueSchemaWithTTL(valEncoder.schema, hasTtl) StateStoreColFamilySchema( stateName, keyEncoder.schema, - getValueSchemaWithTTL(valEncoder.schema, hasTtl), - Some(NoPrefixKeyStateEncoderSpec(keyEncoder.schema))) + valSchema, + Some(NoPrefixKeyStateEncoderSpec(keyEncoder.schema)), + avroEnc = getAvroSerde(keyEncoder.schema, valSchema)) } def getMapStateSchema[K, V]( @@ -55,12 +151,70 @@ object StateStoreColumnFamilySchemaUtils { valEncoder: Encoder[V], hasTtl: Boolean): StateStoreColFamilySchema = { val compositeKeySchema = getCompositeKeySchema(keyEncoder.schema, userKeyEnc.schema) + val valSchema = getValueSchemaWithTTL(valEncoder.schema, hasTtl) StateStoreColFamilySchema( stateName, compositeKeySchema, getValueSchemaWithTTL(valEncoder.schema, hasTtl), Some(PrefixKeyScanStateEncoderSpec(compositeKeySchema, 1)), - Some(userKeyEnc.schema)) + Some(userKeyEnc.schema), + avroEnc = getAvroSerde( + StructType(compositeKeySchema.take(1)), + valSchema, + Some(StructType(compositeKeySchema.drop(1))) + ) + ) + } + + // This function creates the StateStoreColFamilySchema for + // the TTL secondary index. 
+ // Because we want to encode fixed-length types as binary types + // if we are using Avro, we need to do some schema conversion to ensure + // we can use range scan + def getTtlStateSchema( + stateName: String, + keyEncoder: ExpressionEncoder[Any]): StateStoreColFamilySchema = { + val ttlKeySchema = StateStoreColumnFamilySchemaUtils.convertForRangeScan( + getSingleKeyTTLRowSchema(keyEncoder.schema), Seq(0)) + val ttlValSchema = StructType( + Array(StructField("__dummy__", NullType))) + StateStoreColFamilySchema( + stateName, + ttlKeySchema, + ttlValSchema, + Some(RangeKeyScanStateEncoderSpec(ttlKeySchema, Seq(0))), + avroEnc = getAvroSerde( + getSingleKeyTTLRowSchema(keyEncoder.schema), + ttlValSchema, + Some(StructType(ttlKeySchema.drop(2))) + ) + ) + } + + // This function creates the StateStoreColFamilySchema for + // the TTL secondary index. + // Because we want to encode fixed-length types as binary types + // if we are using Avro, we need to do some schema conversion to ensure + // we can use range scan + def getTtlStateSchema( + stateName: String, + keyEncoder: ExpressionEncoder[Any], + userKeySchema: StructType): StateStoreColFamilySchema = { + val ttlKeySchema = StateStoreColumnFamilySchemaUtils.convertForRangeScan( + getCompositeKeyTTLRowSchema(keyEncoder.schema, userKeySchema), Seq(0)) + val ttlValSchema = StructType( + Array(StructField("__dummy__", NullType))) + StateStoreColFamilySchema( + stateName, + ttlKeySchema, + ttlValSchema, + Some(RangeKeyScanStateEncoderSpec(ttlKeySchema, Seq(0))), + avroEnc = getAvroSerde( + getCompositeKeyTTLRowSchema(keyEncoder.schema, userKeySchema), + ttlValSchema, + Some(StructType(ttlKeySchema.drop(2))) + ) + ) } def getTimerStateSchema( @@ -71,6 +225,34 @@ object StateStoreColumnFamilySchemaUtils { stateName, keySchema, valSchema, - Some(PrefixKeyScanStateEncoderSpec(keySchema, 1))) + Some(PrefixKeyScanStateEncoderSpec(keySchema, 1)), + avroEnc = getAvroSerde( + StructType(keySchema.take(1)), + valSchema, + Some(StructType(keySchema.drop(1))) + )) + } + + // This function creates the StateStoreColFamilySchema for + // Timers' secondary index. + // Because we want to encode fixed-length types as binary types + // if we are using Avro, we need to do some schema conversion to ensure + // we can use range scan + def getTimerStateSchemaForSecIndex( + stateName: String, + keySchema: StructType, + valSchema: StructType): StateStoreColFamilySchema = { + val avroKeySchema = StateStoreColumnFamilySchemaUtils. 
+ convertForRangeScan(keySchema, Seq(0)) + StateStoreColFamilySchema( + stateName, + keySchema, + valSchema, + Some(RangeKeyScanStateEncoderSpec(keySchema, Seq(0))), + avroEnc = getAvroSerde( + keySchema, + valSchema, + Some(StructType(avroKeySchema.drop(2))) + )) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 0f90fa8d9e490..8d5ad2ec11098 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.StatefulProcessorHandleState.PRE_INIT +import org.apache.spark.sql.execution.streaming.StateStoreColumnFamilySchemaUtils.getTtlColFamilyName import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{ListState, MapState, QueryInfo, TimeMode, TTLConfig, ValueState} import org.apache.spark.util.Utils @@ -96,6 +97,8 @@ class QueryInfoImpl( * @param isStreaming - defines whether the query is streaming or batch * @param batchTimestampMs - timestamp for the current batch if available * @param metrics - metrics to be updated as part of stateful processing + * @param schemas - StateStoreColumnFamilySchemas that include Avro serializers and deserializers + * for each state variable, if Avro encoding is enabled for this query */ class StatefulProcessorHandleImpl( store: StateStore, @@ -104,7 +107,8 @@ class StatefulProcessorHandleImpl( timeMode: TimeMode, isStreaming: Boolean = true, batchTimestampMs: Option[Long] = None, - metrics: Map[String, SQLMetric] = Map.empty) + metrics: Map[String, SQLMetric] = Map.empty, + schemas: Map[String, StateStoreColFamilySchema] = Map.empty) extends StatefulProcessorHandleImplBase(timeMode, keyEncoder) with Logging { import StatefulProcessorHandleState._ @@ -118,6 +122,14 @@ class StatefulProcessorHandleImpl( currState = CREATED + private def getAvroEnc(stateName: String): Option[AvroEncoder] = { + if (!schemas.contains(stateName)) { + None + } else { + schemas(stateName).avroEnc + } + } + private def buildQueryInfo(): QueryInfo = { val taskCtxOpt = Option(TaskContext.get()) val (queryId, batchId) = if (!isStreaming) { @@ -137,7 +149,13 @@ class StatefulProcessorHandleImpl( override def getQueryInfo(): QueryInfo = currQueryInfo - private lazy val timerState = new TimerStateImpl(store, timeMode, keyEncoder) + private lazy val timerStateName = TimerStateUtils.getTimerStateVarName( + timeMode.toString) + private lazy val timerSecIndexColFamily = TimerStateUtils.getSecIndexColFamilyName( + timeMode.toString) + private lazy val timerState = new TimerStateImpl( + store, timeMode, keyEncoder, getAvroEnc(timerStateName), + getAvroEnc(timerSecIndexColFamily)) /** * Function to register a timer for the given expiryTimestampMs @@ -227,13 +245,14 @@ class StatefulProcessorHandleImpl( validateTTLConfig(ttlConfig, stateName) assert(batchTimestampMs.isDefined) val valueStateWithTTL = new ValueStateImplWithTTL[T](store, stateName, - keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get, metrics) + keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get, metrics, + getAvroEnc(stateName), 
getAvroEnc(getTtlColFamilyName(stateName))) ttlStates.add(valueStateWithTTL) TWSMetricsUtils.incrementMetric(metrics, "numValueStateWithTTLVars") valueStateWithTTL } else { val valueStateWithoutTTL = new ValueStateImpl[T](store, stateName, - keyEncoder, stateEncoder, metrics) + keyEncoder, stateEncoder, metrics, getAvroEnc(stateName)) TWSMetricsUtils.incrementMetric(metrics, "numValueStateVars") valueStateWithoutTTL } @@ -276,13 +295,14 @@ class StatefulProcessorHandleImpl( validateTTLConfig(ttlConfig, stateName) assert(batchTimestampMs.isDefined) val listStateWithTTL = new ListStateImplWithTTL[T](store, stateName, - keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get, metrics) + keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get, metrics, + getAvroEnc(stateName), getAvroEnc(getTtlColFamilyName(stateName))) TWSMetricsUtils.incrementMetric(metrics, "numListStateWithTTLVars") ttlStates.add(listStateWithTTL) listStateWithTTL } else { val listStateWithoutTTL = new ListStateImpl[T](store, stateName, keyEncoder, - stateEncoder, metrics) + stateEncoder, metrics, getAvroEnc(stateName)) TWSMetricsUtils.incrementMetric(metrics, "numListStateVars") listStateWithoutTTL } @@ -314,13 +334,14 @@ class StatefulProcessorHandleImpl( validateTTLConfig(ttlConfig, stateName) assert(batchTimestampMs.isDefined) val mapStateWithTTL = new MapStateImplWithTTL[K, V](store, stateName, keyEncoder, userKeyEnc, - valEncoder, ttlConfig, batchTimestampMs.get, metrics) + valEncoder, ttlConfig, batchTimestampMs.get, metrics, + getAvroEnc(stateName), getAvroEnc(getTtlColFamilyName(stateName))) TWSMetricsUtils.incrementMetric(metrics, "numMapStateWithTTLVars") ttlStates.add(mapStateWithTTL) mapStateWithTTL } else { val mapStateWithoutTTL = new MapStateImpl[K, V](store, stateName, keyEncoder, - userKeyEnc, valEncoder, metrics) + userKeyEnc, valEncoder, metrics, getAvroEnc(stateName)) TWSMetricsUtils.incrementMetric(metrics, "numMapStateVars") mapStateWithoutTTL } @@ -343,7 +364,8 @@ class StatefulProcessorHandleImpl( * actually done. We need this class because we can only collect the schemas after * the StatefulProcessor is initialized. */ -class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: ExpressionEncoder[Any]) +class DriverStatefulProcessorHandleImpl( + timeMode: TimeMode, keyExprEnc: ExpressionEncoder[Any], initializeAvroEnc: Boolean) extends StatefulProcessorHandleImplBase(timeMode, keyExprEnc) { // Because this is only happening on the driver side, there is only @@ -354,6 +376,12 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi private val stateVariableInfos: mutable.Map[String, TransformWithStateVariableInfo] = new mutable.HashMap[String, TransformWithStateVariableInfo]() + // If we want use Avro serializers and deserializers, the schemaUtils will create and populate + // these objects as a part of the schema, and will add this to the map + // These serde objects will eventually be passed to the executors + private val schemaUtils: StateStoreColumnFamilySchemaUtils = + new StateStoreColumnFamilySchemaUtils(initializeAvroEnc) + // If timeMode is not None, add a timer column family schema to the operator metadata so that // registered timers can be read using the state data source reader. 
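// Illustration only, not part of this patch: a sketch of what the driver-side schema
// collection yields for a hypothetical TTL-enabled value state named "countState" when
// the Avro serde is initialized. It assumes the avroEnc field this change adds to
// StateStoreColFamilySchema, and takes the grouping-key ExpressionEncoder as a parameter
// rather than constructing one.
object DriverSchemaCollectionSketch {
  import org.apache.spark.sql.Encoders
  import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
  import org.apache.spark.sql.execution.streaming.StateStoreColumnFamilySchemaUtils
  import org.apache.spark.sql.execution.streaming.StateStoreColumnFamilySchemaUtils.getTtlColFamilyName

  def sketch(keyExprEnc: ExpressionEncoder[Any]): Unit = {
    val schemaUtils = new StateStoreColumnFamilySchemaUtils(initializeAvroSerde = true)
    val schema = schemaUtils.getValueStateSchema(
      "countState", keyExprEnc, Encoders.scalaLong, hasTtl = true)
    // The Avro serializer/deserializer pair is materialized only because
    // initializeAvroSerde = true; with the default UnsafeRow format it would be None.
    assert(schema.avroEnc.isDefined)
    // The TTL secondary index gets its own "$ttl_"-prefixed column family name.
    assert(getTtlColFamilyName("countState") == "$ttl_countState")
  }
}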
if (timeMode != TimeMode.None()) { @@ -372,10 +400,16 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi private def addTimerColFamily(): Unit = { val stateName = TimerStateUtils.getTimerStateVarName(timeMode.toString) + val secIndexColFamilyName = TimerStateUtils.getSecIndexColFamilyName(timeMode.toString) val timerEncoder = new TimerKeyEncoder(keyExprEnc) - val colFamilySchema = StateStoreColumnFamilySchemaUtils. - getTimerStateSchema(stateName, timerEncoder.schemaForKeyRow, timerEncoder.schemaForValueRow) + val colFamilySchema = schemaUtils + .getTimerStateSchema(stateName, timerEncoder.schemaForKeyRow, timerEncoder.schemaForValueRow) + val secIndexColFamilySchema = schemaUtils + .getTimerStateSchemaForSecIndex(secIndexColFamilyName, + timerEncoder.keySchemaForSecIndex, + timerEncoder.schemaForValueRow) columnFamilySchemas.put(stateName, colFamilySchema) + columnFamilySchemas.put(secIndexColFamilyName, secIndexColFamilySchema) val stateVariableInfo = TransformWithStateVariableUtils.getTimerState(stateName) stateVariableInfos.put(stateName, stateVariableInfo) } @@ -394,11 +428,14 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi val ttlEnabled = if (ttlConfig.ttlDuration != null && ttlConfig.ttlDuration.isZero) { false } else { + val ttlColFamilyName = getTtlColFamilyName(stateName) + val ttlColFamilySchema = schemaUtils.getTtlStateSchema(ttlColFamilyName, keyExprEnc) + columnFamilySchemas.put(ttlColFamilyName, ttlColFamilySchema) true } val stateEncoder = encoderFor[T] - val colFamilySchema = StateStoreColumnFamilySchemaUtils. + val colFamilySchema = schemaUtils. getValueStateSchema(stateName, keyExprEnc, stateEncoder, ttlEnabled) checkIfDuplicateVariableDefined(stateName) columnFamilySchemas.put(stateName, colFamilySchema) @@ -422,12 +459,15 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi val ttlEnabled = if (ttlConfig.ttlDuration != null && ttlConfig.ttlDuration.isZero) { false } else { + val ttlColFamilyName = getTtlColFamilyName(stateName) + val ttlColFamilySchema = schemaUtils.getTtlStateSchema(ttlColFamilyName, keyExprEnc) + columnFamilySchemas.put(ttlColFamilyName, ttlColFamilySchema) true } val stateEncoder = encoderFor[T] - val colFamilySchema = StateStoreColumnFamilySchemaUtils. - getListStateSchema(stateName, keyExprEnc, stateEncoder, ttlEnabled) + val colFamilySchema = schemaUtils + .getListStateSchema(stateName, keyExprEnc, stateEncoder, ttlEnabled) checkIfDuplicateVariableDefined(stateName) columnFamilySchemas.put(stateName, colFamilySchema) val stateVariableInfo = TransformWithStateVariableUtils. @@ -449,16 +489,21 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi ttlConfig: TTLConfig): MapState[K, V] = { verifyStateVarOperations("get_map_state", PRE_INIT) + val userKeyEnc = encoderFor[K] + val valEncoder = encoderFor[V] val ttlEnabled = if (ttlConfig.ttlDuration != null && ttlConfig.ttlDuration.isZero) { false } else { + val ttlColFamilyName = getTtlColFamilyName(stateName) + val ttlColFamilySchema = schemaUtils.getTtlStateSchema( + ttlColFamilyName, keyExprEnc, userKeyEnc.schema) + columnFamilySchemas.put(ttlColFamilyName, ttlColFamilySchema) true } - val userKeyEnc = encoderFor[K] - val valEncoder = encoderFor[V] - val colFamilySchema = StateStoreColumnFamilySchemaUtils. 
- getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, ttlEnabled) + + val colFamilySchema = schemaUtils + .getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, ttlEnabled) columnFamilySchemas.put(stateName, colFamilySchema) val stateVariableInfo = TransformWithStateVariableUtils. getMapState(stateName, ttlEnabled = ttlEnabled) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala index 87d1a15dff1a9..02008a1ba4fd0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala @@ -21,8 +21,9 @@ import java.time.Duration import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.execution.streaming.StateStoreColumnFamilySchemaUtils.getTtlColFamilyName import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._ -import org.apache.spark.sql.execution.streaming.state.{RangeKeyScanStateEncoderSpec, StateStore} +import org.apache.spark.sql.execution.streaming.state.{AvroEncoder, RangeKeyScanStateEncoderSpec, StateStore} import org.apache.spark.sql.types._ object StateTTLSchema { @@ -79,12 +80,13 @@ abstract class SingleKeyTTLStateImpl( stateName: String, store: StateStore, keyExprEnc: ExpressionEncoder[Any], - ttlExpirationMs: Long) + ttlExpirationMs: Long, + avroEnc: Option[AvroEncoder] = None) extends TTLState { import org.apache.spark.sql.execution.streaming.StateTTLSchema._ - private val ttlColumnFamilyName = "$ttl_" + stateName + private val ttlColumnFamilyName = getTtlColFamilyName(stateName) private val keySchema = getSingleKeyTTLRowSchema(keyExprEnc.schema) private val keyTTLRowEncoder = new SingleKeyTTLEncoder(keyExprEnc) @@ -93,7 +95,7 @@ abstract class SingleKeyTTLStateImpl( UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null)) store.createColFamilyIfAbsent(ttlColumnFamilyName, keySchema, TTL_VALUE_ROW_SCHEMA, - RangeKeyScanStateEncoderSpec(keySchema, Seq(0)), isInternal = true) + RangeKeyScanStateEncoderSpec(keySchema, Seq(0)), isInternal = true, avroEncoderSpec = avroEnc) /** * This function will be called when clear() on State Variables @@ -199,12 +201,13 @@ abstract class CompositeKeyTTLStateImpl[K]( store: StateStore, keyExprEnc: ExpressionEncoder[Any], userKeyEncoder: ExpressionEncoder[Any], - ttlExpirationMs: Long) + ttlExpirationMs: Long, + avroEnc: Option[AvroEncoder] = None) extends TTLState { import org.apache.spark.sql.execution.streaming.StateTTLSchema._ - private val ttlColumnFamilyName = "$ttl_" + stateName + private val ttlColumnFamilyName = getTtlColFamilyName(stateName) private val keySchema = getCompositeKeyTTLRowSchema( keyExprEnc.schema, userKeyEncoder.schema ) @@ -218,7 +221,7 @@ abstract class CompositeKeyTTLStateImpl[K]( store.createColFamilyIfAbsent(ttlColumnFamilyName, keySchema, TTL_VALUE_ROW_SCHEMA, RangeKeyScanStateEncoderSpec(keySchema, - Seq(0)), isInternal = true) + Seq(0)), isInternal = true, avroEncoderSpec = avroEnc) def clearTTLState(): Unit = { val iterator = store.iterator(ttlColumnFamilyName) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala index d0fbaf6600609..74eaf062ec547 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala @@ -43,6 +43,15 @@ object TimerStateUtils { TimerStateUtils.PROC_TIMERS_STATE_NAME + TimerStateUtils.KEY_TO_TIMESTAMP_CF } } + + def getSecIndexColFamilyName(timeMode: String): String = { + assert(timeMode == TimeMode.EventTime.toString || timeMode == TimeMode.ProcessingTime.toString) + if (timeMode == TimeMode.EventTime.toString) { + TimerStateUtils.EVENT_TIMERS_STATE_NAME + TimerStateUtils.TIMESTAMP_TO_KEY_CF + } else { + TimerStateUtils.PROC_TIMERS_STATE_NAME + TimerStateUtils.TIMESTAMP_TO_KEY_CF + } + } } /** @@ -55,7 +64,9 @@ object TimerStateUtils { class TimerStateImpl( store: StateStore, timeMode: TimeMode, - keyExprEnc: ExpressionEncoder[Any]) extends Logging { + keyExprEnc: ExpressionEncoder[Any], + avroEnc: Option[AvroEncoder] = None, + secIndexAvroEnc: Option[AvroEncoder] = None) extends Logging { private val EMPTY_ROW = UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null)) @@ -75,7 +86,7 @@ class TimerStateImpl( private val keyToTsCFName = timerCFName + TimerStateUtils.KEY_TO_TIMESTAMP_CF store.createColFamilyIfAbsent(keyToTsCFName, schemaForKeyRow, schemaForValueRow, PrefixKeyScanStateEncoderSpec(schemaForKeyRow, 1), - useMultipleValuesPerKey = false, isInternal = true) + useMultipleValuesPerKey = false, isInternal = true, avroEncoderSpec = avroEnc) // We maintain a secondary index that inverts the ordering of the timestamp // and grouping key @@ -83,7 +94,7 @@ class TimerStateImpl( private val tsToKeyCFName = timerCFName + TimerStateUtils.TIMESTAMP_TO_KEY_CF store.createColFamilyIfAbsent(tsToKeyCFName, keySchemaForSecIndex, schemaForValueRow, RangeKeyScanStateEncoderSpec(keySchemaForSecIndex, Seq(0)), - useMultipleValuesPerKey = false, isInternal = true) + useMultipleValuesPerKey = false, isInternal = true, avroEncoderSpec = secIndexAvroEnc) private def getGroupingKey(cfName: String): Any = { val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 4f7a10f882453..adb7c27363ae8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.broadcast.Broadcast +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -75,7 +76,8 @@ case class TransformWithStateExec( initialStateGroupingAttrs: Seq[Attribute], initialStateDataAttrs: Seq[Attribute], initialStateDeserializer: Expression, - initialState: SparkPlan) + initialState: SparkPlan, + columnFamilySchemas: Map[String, StateStoreColFamilySchema] = Map.empty) extends BinaryExecNode with StateStoreWriter with WatermarkSupport with ObjectProducerExec { override def shortName: String = "transformWithStateExec" @@ -104,7 +106,10 @@ case class TransformWithStateExec( * @return a new instance of 
the driver processor handle */ private def getDriverProcessorHandle(): DriverStatefulProcessorHandleImpl = { - val driverProcessorHandle = new DriverStatefulProcessorHandleImpl(timeMode, keyEncoder) + + val driverProcessorHandle = new DriverStatefulProcessorHandleImpl( + timeMode, keyEncoder, initializeAvroEnc = + stateStoreEncoding == StateStoreEncoding.Avro.toString) driverProcessorHandle.setHandleState(StatefulProcessorHandleState.PRE_INIT) statefulProcessor.setHandle(driverProcessorHandle) statefulProcessor.init(outputMode, timeMode) @@ -115,7 +120,7 @@ case class TransformWithStateExec( * Fetching the columnFamilySchemas from the StatefulProcessorHandle * after init is called. */ - private def getColFamilySchemas(): Map[String, StateStoreColFamilySchema] = { + def getColFamilySchemas(): Map[String, StateStoreColFamilySchema] = { val columnFamilySchemas = getDriverProcessorHandle().getColumnFamilySchemas closeProcessorHandle() columnFamilySchemas @@ -470,7 +475,8 @@ case class TransformWithStateExec( newSchemas.values.toList, session.sessionState, stateSchemaVersion, storeName = StateStoreId.DEFAULT_STORE_NAME, oldSchemaFilePath = oldStateSchemaFilePath, - newSchemaFilePath = Some(newStateSchemaFilePath))) + newSchemaFilePath = Some(newStateSchemaFilePath), + usingAvro = true)) } /** Metadata of this stateful operator and its states stores. */ @@ -521,7 +527,6 @@ case class TransformWithStateExec( override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver - validateTimeMode() if (hasInitialState) { @@ -552,10 +557,13 @@ case class TransformWithStateExec( hadoopConf = hadoopConfBroadcast.value.value ) - processDataWithInitialState(store, childDataIterator, initStateIterator) + processDataWithInitialState( + store, childDataIterator, initStateIterator, columnFamilySchemas) } else { - initNewStateStoreAndProcessData(partitionId, hadoopConfBroadcast) { store => - processDataWithInitialState(store, childDataIterator, initStateIterator) + initNewStateStoreAndProcessData( + partitionId, hadoopConfBroadcast) { store => + processDataWithInitialState( + store, childDataIterator, initStateIterator, columnFamilySchemas) } } } @@ -571,7 +579,7 @@ case class TransformWithStateExec( useColumnFamilies = true ) { case (store: StateStore, singleIterator: Iterator[InternalRow]) => - processData(store, singleIterator) + processData(store, singleIterator, columnFamilySchemas) } } else { // If the query is running in batch mode, we need to create a new StateStore and instantiate @@ -580,8 +588,9 @@ case class TransformWithStateExec( new SerializableConfiguration(session.sessionState.newHadoopConf())) child.execute().mapPartitionsWithIndex[InternalRow]( (i: Int, iter: Iterator[InternalRow]) => { - initNewStateStoreAndProcessData(i, hadoopConfBroadcast) { store => - processData(store, iter) + initNewStateStoreAndProcessData( + i, hadoopConfBroadcast) { store => + processData(store, iter, columnFamilySchemas) } } ) @@ -596,7 +605,8 @@ case class TransformWithStateExec( private def initNewStateStoreAndProcessData( partitionId: Int, hadoopConfBroadcast: Broadcast[SerializableConfiguration]) - (f: StateStore => CompletionIterator[InternalRow, Iterator[InternalRow]]): + (f: StateStore => + CompletionIterator[InternalRow, Iterator[InternalRow]]): CompletionIterator[InternalRow, Iterator[InternalRow]] = { val providerId = { @@ -621,7 +631,7 @@ case class TransformWithStateExec( hadoopConf = hadoopConfBroadcast.value.value, useMultipleValuesPerKey = true) - val store = 
stateStoreProvider.getStore(0, None) + val store = stateStoreProvider.getStore(0) val outputIterator = f(store) CompletionIterator[InternalRow, Iterator[InternalRow]](outputIterator.iterator, { stateStoreProvider.close() @@ -633,13 +643,17 @@ case class TransformWithStateExec( * Process the data in the partition using the state store and the stateful processor. * @param store The state store to use * @param singleIterator The iterator of rows to process + * @param schemas The column family schemas used by this stateful processor * @return An iterator of rows that are the result of processing the input rows */ - private def processData(store: StateStore, singleIterator: Iterator[InternalRow]): + private def processData( + store: StateStore, + singleIterator: Iterator[InternalRow], + schemas: Map[String, StateStoreColFamilySchema]): CompletionIterator[InternalRow, Iterator[InternalRow]] = { val processorHandle = new StatefulProcessorHandleImpl( store, getStateInfo.queryRunId, keyEncoder, timeMode, - isStreaming, batchTimestampMs, metrics) + isStreaming, batchTimestampMs, metrics, schemas) assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED) statefulProcessor.setHandle(processorHandle) statefulProcessor.init(outputMode, timeMode) @@ -650,10 +664,11 @@ case class TransformWithStateExec( private def processDataWithInitialState( store: StateStore, childDataIterator: Iterator[InternalRow], - initStateIterator: Iterator[InternalRow]): + initStateIterator: Iterator[InternalRow], + schemas: Map[String, StateStoreColFamilySchema]): CompletionIterator[InternalRow, Iterator[InternalRow]] = { val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId, - keyEncoder, timeMode, isStreaming, batchTimestampMs, metrics) + keyEncoder, timeMode, isStreaming, batchTimestampMs, metrics, schemas) assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED) statefulProcessor.setHandle(processorHandle) statefulProcessor.init(outputMode, timeMode) @@ -689,7 +704,7 @@ case class TransformWithStateExec( } // scalastyle:off argcount -object TransformWithStateExec { +object TransformWithStateExec extends Logging { // Plan logical transformWithState for batch queries def generateSparkPlanForBatchQueries( @@ -718,6 +733,22 @@ object TransformWithStateExec { stateStoreCkptIds = None ) + val stateStoreEncoding = child.session.sessionState.conf.getConf( + SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT + ) + + def getDriverProcessorHandle(): DriverStatefulProcessorHandleImpl = { + val driverProcessorHandle = new DriverStatefulProcessorHandleImpl( + timeMode, keyEncoder, initializeAvroEnc = + stateStoreEncoding == StateStoreEncoding.Avro.toString) + driverProcessorHandle.setHandleState(StatefulProcessorHandleState.PRE_INIT) + statefulProcessor.setHandle(driverProcessorHandle) + statefulProcessor.init(outputMode, timeMode) + driverProcessorHandle + } + + val columnFamilySchemas = getDriverProcessorHandle().getColumnFamilySchemas + new TransformWithStateExec( keyDeserializer, valueDeserializer, @@ -738,7 +769,8 @@ object TransformWithStateExec { initialStateGroupingAttrs, initialStateDataAttrs, initialStateDeserializer, - initialState) + initialState, + columnFamilySchemas) } } // scalastyle:on argcount diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala index cd66bf99d4e15..9eb51abaee6e3 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} +import org.apache.spark.sql.execution.streaming.state.{AvroEncoder, NoPrefixKeyStateEncoderSpec, StateStore} import org.apache.spark.sql.streaming.ValueState /** @@ -30,6 +30,8 @@ import org.apache.spark.sql.streaming.ValueState * @param keyExprEnc - Spark SQL encoder for key * @param valEncoder - Spark SQL encoder for value * @param metrics - metrics to be updated as part of stateful processing + * @param avroEnc - optional Avro serializer and deserializer for this state variable that + * is used by the StateStore to encode state in Avro format * @tparam S - data type of object that will be stored */ class ValueStateImpl[S]( @@ -37,7 +39,8 @@ class ValueStateImpl[S]( stateName: String, keyExprEnc: ExpressionEncoder[Any], valEncoder: ExpressionEncoder[Any], - metrics: Map[String, SQLMetric] = Map.empty) + metrics: Map[String, SQLMetric] = Map.empty, + avroEnc: Option[AvroEncoder] = None) extends ValueState[S] with Logging { private val stateTypesEncoder = StateTypesEncoder(keyExprEnc, valEncoder, stateName) @@ -46,7 +49,7 @@ class ValueStateImpl[S]( private def initialize(): Unit = { store.createColFamilyIfAbsent(stateName, keyExprEnc.schema, valEncoder.schema, - NoPrefixKeyStateEncoderSpec(keyExprEnc.schema)) + NoPrefixKeyStateEncoderSpec(keyExprEnc.schema), avroEncoderSpec = avroEnc) } /** Function to check if state exists. Returns true if present and false otherwise */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala index 60eea5842645e..7c2401dffb2f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala @@ -20,7 +20,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._ -import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} +import org.apache.spark.sql.execution.streaming.state.{AvroEncoder, NoPrefixKeyStateEncoderSpec, StateStore} import org.apache.spark.sql.streaming.{TTLConfig, ValueState} /** @@ -34,6 +34,10 @@ import org.apache.spark.sql.streaming.{TTLConfig, ValueState} * @param ttlConfig - TTL configuration for values stored in this state * @param batchTimestampMs - current batch processing timestamp. 
* @param metrics - metrics to be updated as part of stateful processing + * @param avroEnc - optional Avro serializer and deserializer for this state variable that + * is used by the StateStore to encode state in Avro format + * @param secondaryIndexAvroEnc - optional Avro serializer and deserializer for TTL state that + * is used by the StateStore to encode state in Avro format * @tparam S - data type of object that will be stored */ class ValueStateImplWithTTL[S]( @@ -43,9 +47,11 @@ class ValueStateImplWithTTL[S]( valEncoder: ExpressionEncoder[Any], ttlConfig: TTLConfig, batchTimestampMs: Long, - metrics: Map[String, SQLMetric] = Map.empty) + metrics: Map[String, SQLMetric] = Map.empty, + avroEnc: Option[AvroEncoder] = None, + secondaryIndexAvroEnc: Option[AvroEncoder] = None) extends SingleKeyTTLStateImpl( - stateName, store, keyExprEnc, batchTimestampMs) with ValueState[S] { + stateName, store, keyExprEnc, batchTimestampMs, secondaryIndexAvroEnc) with ValueState[S] { private val stateTypesEncoder = StateTypesEncoder(keyExprEnc, valEncoder, stateName, hasTtl = true) @@ -57,7 +63,7 @@ class ValueStateImplWithTTL[S]( private def initialize(): Unit = { store.createColFamilyIfAbsent(stateName, keyExprEnc.schema, getValueSchemaWithTTL(valEncoder.schema, true), - NoPrefixKeyStateEncoderSpec(keyExprEnc.schema)) + NoPrefixKeyStateEncoderSpec(keyExprEnc.schema), avroEncoderSpec = avroEnc) } /** Function to check if state exists. Returns true if present and false otherwise */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 2f77b2c14b009..423ce50776fa4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -127,7 +127,8 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with valueSchema: StructType, keyStateEncoderSpec: KeyStateEncoderSpec, useMultipleValuesPerKey: Boolean = false, - isInternal: Boolean = false): Unit = { + isInternal: Boolean = false, + avroEnc: Option[AvroEncoder]): Unit = { throw StateStoreErrors.multipleColumnFamiliesNotSupported(providerName) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index 4c7a226e0973f..c69cf6efc813b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -17,13 +17,21 @@ package org.apache.spark.sql.execution.streaming.state +import java.io.ByteArrayOutputStream import java.lang.Double.{doubleToRawLongBits, longBitsToDouble} import java.lang.Float.{floatToRawIntBits, intBitsToFloat} import java.nio.{ByteBuffer, ByteOrder} +import org.apache.avro.Schema +import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord} +import org.apache.avro.io.{DecoderFactory, EncoderFactory} + import org.apache.spark.internal.Logging +import org.apache.spark.sql.avro.{AvroDeserializer, AvroSerializer, SchemaConverters} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{BoundReference, 
JoinedRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter +import org.apache.spark.sql.execution.streaming.StateStoreColumnFamilySchemaUtils import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{STATE_ENCODING_NUM_VERSION_BYTES, STATE_ENCODING_VERSION, VIRTUAL_COL_FAMILY_PREFIX_BYTES} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform @@ -49,6 +57,7 @@ abstract class RocksDBKeyStateEncoderBase( def offsetForColFamilyPrefix: Int = if (useColumnFamilies) VIRTUAL_COL_FAMILY_PREFIX_BYTES else 0 + val out = new ByteArrayOutputStream /** * Get Byte Array for the virtual column family id that is used as prefix for * key state rows. @@ -89,23 +98,24 @@ abstract class RocksDBKeyStateEncoderBase( } } -object RocksDBStateEncoder { +object RocksDBStateEncoder extends Logging { def getKeyEncoder( keyStateEncoderSpec: KeyStateEncoderSpec, useColumnFamilies: Boolean, - virtualColFamilyId: Option[Short] = None): RocksDBKeyStateEncoder = { + virtualColFamilyId: Option[Short] = None, + avroEnc: Option[AvroEncoder] = None): RocksDBKeyStateEncoder = { // Return the key state encoder based on the requested type keyStateEncoderSpec match { case NoPrefixKeyStateEncoderSpec(keySchema) => - new NoPrefixKeyStateEncoder(keySchema, useColumnFamilies, virtualColFamilyId) + new NoPrefixKeyStateEncoder(keySchema, useColumnFamilies, virtualColFamilyId, avroEnc) case PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) => new PrefixKeyScanStateEncoder(keySchema, numColsPrefixKey, - useColumnFamilies, virtualColFamilyId) + useColumnFamilies, virtualColFamilyId, avroEnc) case RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) => new RangeKeyScanStateEncoder(keySchema, orderingOrdinals, - useColumnFamilies, virtualColFamilyId) + useColumnFamilies, virtualColFamilyId, avroEnc) case _ => throw new IllegalArgumentException(s"Unsupported key state encoder spec: " + @@ -115,11 +125,12 @@ object RocksDBStateEncoder { def getValueEncoder( valueSchema: StructType, - useMultipleValuesPerKey: Boolean): RocksDBValueStateEncoder = { + useMultipleValuesPerKey: Boolean, + avroEnc: Option[AvroEncoder] = None): RocksDBValueStateEncoder = { if (useMultipleValuesPerKey) { - new MultiValuedStateEncoder(valueSchema) + new MultiValuedStateEncoder(valueSchema, avroEnc) } else { - new SingleValueStateEncoder(valueSchema) + new SingleValueStateEncoder(valueSchema, avroEnc) } } @@ -145,6 +156,26 @@ object RocksDBStateEncoder { encodedBytes } + /** + * This method takes an UnsafeRow, and serializes to a byte array using Avro encoding. 
+ */ + def encodeUnsafeRowToAvro( + row: UnsafeRow, + avroSerializer: AvroSerializer, + valueAvroType: Schema, + out: ByteArrayOutputStream): Array[Byte] = { + // InternalRow -> Avro.GenericDataRecord + val avroData = + avroSerializer.serialize(row) + out.reset() + val encoder = EncoderFactory.get().directBinaryEncoder(out, null) + val writer = new GenericDatumWriter[Any]( + valueAvroType) // Defining Avro writer for this struct type + writer.write(avroData, encoder) // Avro.GenericDataRecord -> byte array + encoder.flush() + out.toByteArray + } + def decodeToUnsafeRow(bytes: Array[Byte], numFields: Int): UnsafeRow = { if (bytes != null) { val row = new UnsafeRow(numFields) @@ -154,6 +185,26 @@ object RocksDBStateEncoder { } } + /** + * This method takes a byte array written using Avro encoding, and + * deserializes to an UnsafeRow using the Avro deserializer + */ + def decodeFromAvroToUnsafeRow( + valueBytes: Array[Byte], + avroDeserializer: AvroDeserializer, + valueAvroType: Schema, + valueProj: UnsafeProjection): UnsafeRow = { + val reader = new GenericDatumReader[Any](valueAvroType) + val decoder = DecoderFactory.get().binaryDecoder(valueBytes, 0, valueBytes.length, null) + // bytes -> Avro.GenericDataRecord + val genericData = reader.read(null, decoder) + // Avro.GenericDataRecord -> InternalRow + val internalRow = avroDeserializer.deserialize( + genericData).orNull.asInstanceOf[InternalRow] + // InternalRow -> UnsafeRow + valueProj.apply(internalRow) + } + def decodeToUnsafeRow(bytes: Array[Byte], reusedRow: UnsafeRow): UnsafeRow = { if (bytes != null) { // Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st offset. See Platform. @@ -174,16 +225,20 @@ object RocksDBStateEncoder { * @param keySchema - schema of the key to be encoded * @param numColsPrefixKey - number of columns to be used for prefix key * @param useColumnFamilies - if column family is enabled for this encoder + * @param avroEnc - if Avro encoding is specified for this StateEncoder, this encoder will + * be defined */ class PrefixKeyScanStateEncoder( keySchema: StructType, numColsPrefixKey: Int, useColumnFamilies: Boolean = false, - virtualColFamilyId: Option[Short] = None) + virtualColFamilyId: Option[Short] = None, + avroEnc: Option[AvroEncoder] = None) extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) { import RocksDBStateEncoder._ + private val usingAvroEncoding = avroEnc.isDefined private val prefixKeyFieldsWithIdx: Seq[(StructField, Int)] = { keySchema.zipWithIndex.take(numColsPrefixKey) } @@ -203,6 +258,18 @@ class PrefixKeyScanStateEncoder( UnsafeProjection.create(refs) } + // Prefix Key schema and projection definitions used by the Avro Serializers + // and Deserializers + private val prefixKeySchema = StructType(keySchema.take(numColsPrefixKey)) + private lazy val prefixKeyAvroType = SchemaConverters.toAvroType(prefixKeySchema) + private val prefixKeyProj = UnsafeProjection.create(prefixKeySchema) + + // Remaining Key schema and projection definitions used by the Avro Serializers + // and Deserializers + private val remainingKeySchema = StructType(keySchema.drop(numColsPrefixKey)) + private lazy val remainingKeyAvroType = SchemaConverters.toAvroType(remainingKeySchema) + private val remainingKeyProj = UnsafeProjection.create(remainingKeySchema) + // This is quite simple to do - just bind sequentially, as we don't change the order. 
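A minimal round-trip sketch (not part of this patch) of how the two helpers above are meant to compose: an UnsafeRow is serialized to Avro binary with encodeUnsafeRowToAvro and rebuilt with decodeFromAvroToUnsafeRow through the matching Avro schema and projection. The helper name roundTripThroughAvro and its parameters are hypothetical; the serializer and deserializer stand in for the ones an AvroEncoder carries.

    import java.io.ByteArrayOutputStream

    import org.apache.spark.sql.avro.{AvroDeserializer, AvroSerializer, SchemaConverters}
    import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
    import org.apache.spark.sql.execution.streaming.state.RocksDBStateEncoder._
    import org.apache.spark.sql.types.StructType

    // Hypothetical helper: UnsafeRow -> Avro bytes -> UnsafeRow for a fixed schema.
    def roundTripThroughAvro(
        row: UnsafeRow,
        schema: StructType,
        ser: AvroSerializer,
        deser: AvroDeserializer): UnsafeRow = {
      // Catalyst schema -> Avro schema, plus a projection used to rebuild an UnsafeRow
      val avroType = SchemaConverters.toAvroType(schema)
      val proj = UnsafeProjection.create(schema)
      val buffer = new ByteArrayOutputStream()
      // UnsafeRow -> Avro bytes
      val avroBytes = encodeUnsafeRowToAvro(row, ser, avroType, buffer)
      // Avro bytes -> UnsafeRow
      decodeFromAvroToUnsafeRow(avroBytes, deser, avroType, proj)
    }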
private val restoreKeyProjection: UnsafeProjection = UnsafeProjection.create(keySchema) @@ -210,9 +277,24 @@ class PrefixKeyScanStateEncoder( private val joinedRowOnKey = new JoinedRow() override def encodeKey(row: UnsafeRow): Array[Byte] = { - val prefixKeyEncoded = encodeUnsafeRow(extractPrefixKey(row)) - val remainingEncoded = encodeUnsafeRow(remainingKeyProjection(row)) - + val (prefixKeyEncoded, remainingEncoded) = if (usingAvroEncoding) { + ( + encodeUnsafeRowToAvro( + extractPrefixKey(row), + avroEnc.get.keySerializer, + prefixKeyAvroType, + out + ), + encodeUnsafeRowToAvro( + remainingKeyProjection(row), + avroEnc.get.suffixKeySerializer.get, + remainingKeyAvroType, + out + ) + ) + } else { + (encodeUnsafeRow(extractPrefixKey(row)), encodeUnsafeRow(remainingKeyProjection(row))) + } val (encodedBytes, startingOffset) = encodeColumnFamilyPrefix( prefixKeyEncoded.length + remainingEncoded.length + 4 ) @@ -243,9 +325,25 @@ class PrefixKeyScanStateEncoder( Platform.copyMemory(keyBytes, decodeKeyStartOffset + 4 + prefixKeyEncodedLen, remainingKeyEncoded, Platform.BYTE_ARRAY_OFFSET, remainingKeyEncodedLen) - val prefixKeyDecoded = decodeToUnsafeRow(prefixKeyEncoded, numFields = numColsPrefixKey) - val remainingKeyDecoded = decodeToUnsafeRow(remainingKeyEncoded, - numFields = keySchema.length - numColsPrefixKey) + val (prefixKeyDecoded, remainingKeyDecoded) = if (usingAvroEncoding) { + ( + decodeFromAvroToUnsafeRow( + prefixKeyEncoded, + avroEnc.get.keyDeserializer, + prefixKeyAvroType, + prefixKeyProj + ), + decodeFromAvroToUnsafeRow( + remainingKeyEncoded, + avroEnc.get.suffixKeyDeserializer.get, + remainingKeyAvroType, + remainingKeyProj + ) + ) + } else { + (decodeToUnsafeRow(prefixKeyEncoded, numFields = numColsPrefixKey), + decodeToUnsafeRow(remainingKeyEncoded, numFields = keySchema.length - numColsPrefixKey)) + } restoreKeyProjection(joinedRowOnKey.withLeft(prefixKeyDecoded).withRight(remainingKeyDecoded)) } @@ -255,7 +353,11 @@ class PrefixKeyScanStateEncoder( } override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = { - val prefixKeyEncoded = encodeUnsafeRow(prefixKey) + val prefixKeyEncoded = if (usingAvroEncoding) { + encodeUnsafeRowToAvro(prefixKey, avroEnc.get.keySerializer, prefixKeyAvroType, out) + } else { + encodeUnsafeRow(prefixKey) + } val (prefix, startingOffset) = encodeColumnFamilyPrefix( prefixKeyEncoded.length + 4 ) @@ -299,13 +401,16 @@ class PrefixKeyScanStateEncoder( * @param keySchema - schema of the key to be encoded * @param orderingOrdinals - the ordinals for which the range scan is constructed * @param useColumnFamilies - if column family is enabled for this encoder + * @param avroEnc - if Avro encoding is specified for this StateEncoder, this encoder will + * be defined */ class RangeKeyScanStateEncoder( keySchema: StructType, orderingOrdinals: Seq[Int], useColumnFamilies: Boolean = false, - virtualColFamilyId: Option[Short] = None) - extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) { + virtualColFamilyId: Option[Short] = None, + avroEnc: Option[AvroEncoder] = None) + extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) with Logging { import RocksDBStateEncoder._ @@ -374,6 +479,22 @@ class RangeKeyScanStateEncoder( UnsafeProjection.create(refs) } + private val rangeScanAvroSchema = StateStoreColumnFamilySchemaUtils.convertForRangeScan( + StructType(rangeScanKeyFieldsWithOrdinal.map(_._1).toArray)) + + private lazy val rangeScanAvroType = SchemaConverters.toAvroType(rangeScanAvroSchema) + + private val 
rangeScanAvroProjection = UnsafeProjection.create(rangeScanAvroSchema) + + // Existing remainder key schema stuff + private val remainingKeySchema = StructType( + 0.to(keySchema.length - 1).diff(orderingOrdinals).map(keySchema(_)) + ) + + private lazy val remainingKeyAvroType = SchemaConverters.toAvroType(remainingKeySchema) + + private val remainingKeyAvroProjection = UnsafeProjection.create(remainingKeySchema) + // Reusable objects private val joinedRowOnKey = new JoinedRow() @@ -563,13 +684,242 @@ class RangeKeyScanStateEncoder( writer.getRow() } + def encodePrefixKeyForRangeScan( + row: UnsafeRow, + avroType: Schema + ): Array[Byte] = { + val record = new GenericData.Record(avroType) + var fieldIdx = 0 + rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) => + val field = fieldWithOrdinal._1 + val value = row.get(idx, field.dataType) + if (value == null) { + record.put(fieldIdx, nullValMarker) // isNull marker + record.put(fieldIdx + 1, new Array[Byte](field.dataType.defaultSize)) + } else { + field.dataType match { + case BooleanType => + val boolVal = value.asInstanceOf[Boolean] + record.put(fieldIdx, positiveValMarker) // not null marker + record.put(fieldIdx + 1, ByteBuffer.wrap(Array[Byte](if (boolVal) 1 else 0))) + + case ByteType => + val byteVal = value.asInstanceOf[Byte] + val marker = positiveValMarker + record.put(fieldIdx, marker) + + val bytes = new Array[Byte](1) + bytes(0) = byteVal + record.put(fieldIdx + 1, ByteBuffer.wrap(bytes)) + + case ShortType => + val shortVal = value.asInstanceOf[Short] + val marker = if (shortVal >= 0) positiveValMarker else negativeValMarker + record.put(fieldIdx, marker) + + val bbuf = ByteBuffer.allocate(2) + bbuf.order(ByteOrder.BIG_ENDIAN) + bbuf.putShort(shortVal) + val bytes = new Array[Byte](2) + bbuf.position(0) + bbuf.get(bytes) + record.put(fieldIdx + 1, ByteBuffer.wrap(bytes)) + + case IntegerType => + val intVal = value.asInstanceOf[Int] + val marker = if (intVal >= 0) positiveValMarker else negativeValMarker + record.put(fieldIdx, marker) + + val bbuf = ByteBuffer.allocate(4) + bbuf.order(ByteOrder.BIG_ENDIAN) + bbuf.putInt(intVal) + val bytes = new Array[Byte](4) + bbuf.position(0) + bbuf.get(bytes) + record.put(fieldIdx + 1, ByteBuffer.wrap(bytes)) + + case LongType => + val longVal = value.asInstanceOf[Long] + val marker = if (longVal >= 0) positiveValMarker else negativeValMarker + record.put(fieldIdx, marker) + + val bbuf = ByteBuffer.allocate(8) + bbuf.order(ByteOrder.BIG_ENDIAN) + bbuf.putLong(longVal) + val bytes = new Array[Byte](8) + bbuf.position(0) + bbuf.get(bytes) + record.put(fieldIdx + 1, ByteBuffer.wrap(bytes)) + + case FloatType => + val floatVal = value.asInstanceOf[Float] + val rawBits = floatToRawIntBits(floatVal) + + val bbuf = ByteBuffer.allocate(4) + bbuf.order(ByteOrder.BIG_ENDIAN) + if ((rawBits & floatSignBitMask) != 0) { + record.put(fieldIdx, negativeValMarker) + // For negative values, flip the bits to maintain proper ordering + val updatedVal = rawBits ^ floatFlipBitMask + bbuf.putFloat(intBitsToFloat(updatedVal)) + } else { + record.put(fieldIdx, positiveValMarker) + bbuf.putFloat(floatVal) + } + val bytes = new Array[Byte](4) + bbuf.position(0) + bbuf.get(bytes) + record.put(fieldIdx + 1, ByteBuffer.wrap(bytes)) + + case DoubleType => + val doubleVal = value.asInstanceOf[Double] + val rawBits = doubleToRawLongBits(doubleVal) + + val bbuf = ByteBuffer.allocate(8) + bbuf.order(ByteOrder.BIG_ENDIAN) + if ((rawBits & doubleSignBitMask) != 0) { + // For negative values, 
flip the bits to maintain proper ordering + record.put(fieldIdx, negativeValMarker) + val updatedVal = rawBits ^ doubleFlipBitMask + bbuf.putDouble(longBitsToDouble(updatedVal)) + } else { + record.put(fieldIdx, positiveValMarker) + bbuf.putDouble(doubleVal) + } + val bytes = new Array[Byte](8) + bbuf.position(0) + bbuf.get(bytes) + record.put(fieldIdx + 1, ByteBuffer.wrap(bytes)) + + case _ => throw new UnsupportedOperationException( + s"Range scan encoding not supported for data type: ${field.dataType}") + } + } + fieldIdx += 2 + } + + out.reset() + val writer = new GenericDatumWriter[GenericRecord](rangeScanAvroType) + val encoder = EncoderFactory.get().binaryEncoder(out, null) + writer.write(record, encoder) + encoder.flush() + out.toByteArray + } + + def decodePrefixKeyForRangeScan( + bytes: Array[Byte], + avroType: Schema): UnsafeRow = { + + val reader = new GenericDatumReader[GenericRecord](avroType) + val decoder = DecoderFactory.get().binaryDecoder(bytes, 0, bytes.length, null) + val record = reader.read(null, decoder) + + val rowWriter = new UnsafeRowWriter(rangeScanKeyFieldsWithOrdinal.length) + rowWriter.resetRowWriter() + + var fieldIdx = 0 + rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) => + val field = fieldWithOrdinal._1 + val isMarkerNull = record.get(fieldIdx) == nullValMarker + + if (isMarkerNull) { + rowWriter.setNullAt(idx) + } else { + field.dataType match { + case BooleanType => + val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array() + rowWriter.write(idx, bytes(0) == 1) + + case ByteType => + val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array() + rowWriter.write(idx, bytes(0)) + + case ShortType => + val byteBuf = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer] + val bbuf = ByteBuffer.allocate(2) + bbuf.order(ByteOrder.BIG_ENDIAN) + bbuf.put(byteBuf.array(), byteBuf.position(), byteBuf.remaining()) + bbuf.flip() + rowWriter.write(idx, bbuf.getShort) + + case IntegerType => + val byteBuf = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer] + val bbuf = ByteBuffer.allocate(4) + bbuf.order(ByteOrder.BIG_ENDIAN) + bbuf.put(byteBuf.array(), byteBuf.position(), byteBuf.remaining()) + bbuf.flip() + rowWriter.write(idx, bbuf.getInt) + + case LongType => + val byteBuf = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer] + val bbuf = ByteBuffer.allocate(8) + bbuf.order(ByteOrder.BIG_ENDIAN) + bbuf.put(byteBuf.array(), byteBuf.position(), byteBuf.remaining()) + bbuf.flip() + rowWriter.write(idx, bbuf.getLong) + + case FloatType => + val byteBuf = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer] + val bbuf = ByteBuffer.allocate(4) + bbuf.order(ByteOrder.BIG_ENDIAN) + bbuf.put(byteBuf.array(), byteBuf.position(), byteBuf.remaining()) + bbuf.flip() + + val isNegative = record.get(fieldIdx).asInstanceOf[Byte] == negativeValMarker + if (isNegative) { + val floatVal = bbuf.getFloat + val updatedVal = floatToRawIntBits(floatVal) ^ floatFlipBitMask + rowWriter.write(idx, intBitsToFloat(updatedVal)) + } else { + rowWriter.write(idx, bbuf.getFloat) + } + + case DoubleType => + val byteBuf = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer] + val bbuf = ByteBuffer.allocate(8) + bbuf.order(ByteOrder.BIG_ENDIAN) + bbuf.put(byteBuf.array(), byteBuf.position(), byteBuf.remaining()) + bbuf.flip() + + val isNegative = record.get(fieldIdx).asInstanceOf[Byte] == negativeValMarker + if (isNegative) { + val doubleVal = bbuf.getDouble + val updatedVal = doubleToRawLongBits(doubleVal) ^ doubleFlipBitMask + rowWriter.write(idx, 
longBitsToDouble(updatedVal)) + } else { + rowWriter.write(idx, bbuf.getDouble) + } + + case _ => throw new UnsupportedOperationException( + s"Range scan decoding not supported for data type: ${field.dataType}") + } + } + fieldIdx += 2 + } + + rowWriter.getRow() + } + override def encodeKey(row: UnsafeRow): Array[Byte] = { // This prefix key has the columns specified by orderingOrdinals val prefixKey = extractPrefixKey(row) - val rangeScanKeyEncoded = encodeUnsafeRow(encodePrefixKeyForRangeScan(prefixKey)) + val rangeScanKeyEncoded = if (avroEnc.isDefined) { + encodePrefixKeyForRangeScan(prefixKey, rangeScanAvroType) + } else { + encodeUnsafeRow(encodePrefixKeyForRangeScan(prefixKey)) + } val result = if (orderingOrdinals.length < keySchema.length) { - val remainingEncoded = encodeUnsafeRow(remainingKeyProjection(row)) + val remainingEncoded = if (avroEnc.isDefined) { + encodeUnsafeRowToAvro( + remainingKeyProjection(row), + avroEnc.get.suffixKeySerializer.get, + remainingKeyAvroType, + out + ) + } else { + encodeUnsafeRow(remainingKeyProjection(row)) + } val (encodedBytes, startingOffset) = encodeColumnFamilyPrefix( rangeScanKeyEncoded.length + remainingEncoded.length + 4 ) @@ -606,9 +956,12 @@ class RangeKeyScanStateEncoder( Platform.copyMemory(keyBytes, decodeKeyStartOffset + 4, prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET, prefixKeyEncodedLen) - val prefixKeyDecodedForRangeScan = decodeToUnsafeRow(prefixKeyEncoded, - numFields = orderingOrdinals.length) - val prefixKeyDecoded = decodePrefixKeyForRangeScan(prefixKeyDecodedForRangeScan) + val prefixKeyDecoded = if (avroEnc.isDefined) { + decodePrefixKeyForRangeScan(prefixKeyEncoded, rangeScanAvroType) + } else { + decodePrefixKeyForRangeScan(decodeToUnsafeRow(prefixKeyEncoded, + numFields = orderingOrdinals.length)) + } if (orderingOrdinals.length < keySchema.length) { // Here we calculate the remainingKeyEncodedLen leveraging the length of keyBytes @@ -620,8 +973,14 @@ class RangeKeyScanStateEncoder( remainingKeyEncoded, Platform.BYTE_ARRAY_OFFSET, remainingKeyEncodedLen) - val remainingKeyDecoded = decodeToUnsafeRow(remainingKeyEncoded, - numFields = keySchema.length - orderingOrdinals.length) + val remainingKeyDecoded = if (avroEnc.isDefined) { + decodeFromAvroToUnsafeRow(remainingKeyEncoded, + avroEnc.get.suffixKeyDeserializer.get, + remainingKeyAvroType, remainingKeyAvroProjection) + } else { + decodeToUnsafeRow(remainingKeyEncoded, + numFields = keySchema.length - orderingOrdinals.length) + } val joined = joinedRowOnKey.withLeft(prefixKeyDecoded).withRight(remainingKeyDecoded) val restored = restoreKeyProjection(joined) @@ -634,7 +993,11 @@ class RangeKeyScanStateEncoder( } override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = { - val rangeScanKeyEncoded = encodeUnsafeRow(encodePrefixKeyForRangeScan(prefixKey)) + val rangeScanKeyEncoded = if (avroEnc.isDefined) { + encodePrefixKeyForRangeScan(prefixKey, rangeScanAvroType) + } else { + encodeUnsafeRow(encodePrefixKeyForRangeScan(prefixKey)) + } val (prefix, startingOffset) = encodeColumnFamilyPrefix(rangeScanKeyEncoded.length + 4) Platform.putInt(prefix, startingOffset, rangeScanKeyEncoded.length) @@ -653,6 +1016,7 @@ class RangeKeyScanStateEncoder( * It uses the first byte of the generated byte array to store the version the describes how the * row is encoded in the rest of the byte array. 
Currently, the default version is 0, * + * If the avroEnc is specified, we are using Avro encoding for this column family's keys * VERSION 0: [ VERSION (1 byte) | ROW (N bytes) ] * The bytes of a UnsafeRow is written unmodified to starting from offset 1 * (offset 0 is the version byte of value 0). That is, if the unsafe row has N bytes, @@ -661,19 +1025,27 @@ class RangeKeyScanStateEncoder( class NoPrefixKeyStateEncoder( keySchema: StructType, useColumnFamilies: Boolean = false, - virtualColFamilyId: Option[Short] = None) - extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) { + virtualColFamilyId: Option[Short] = None, + avroEnc: Option[AvroEncoder] = None) + extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) with Logging { import RocksDBStateEncoder._ // Reusable objects + private val usingAvroEncoding = avroEnc.isDefined private val keyRow = new UnsafeRow(keySchema.size) + private lazy val keyAvroType = SchemaConverters.toAvroType(keySchema) + private val keyProj = UnsafeProjection.create(keySchema) override def encodeKey(row: UnsafeRow): Array[Byte] = { if (!useColumnFamilies) { encodeUnsafeRow(row) } else { - val bytesToEncode = row.getBytes + // If avroEnc is defined, we know that we need to use Avro to + // encode this UnsafeRow to Avro bytes + val bytesToEncode = if (usingAvroEncoding) { + encodeUnsafeRowToAvro(row, avroEnc.get.keySerializer, keyAvroType, out) + } else row.getBytes val (encodedBytes, startingOffset) = encodeColumnFamilyPrefix( bytesToEncode.length + STATE_ENCODING_NUM_VERSION_BYTES @@ -697,11 +1069,25 @@ class NoPrefixKeyStateEncoder( if (useColumnFamilies) { if (keyBytes != null) { // Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st offset. See Platform. - keyRow.pointTo( - keyBytes, - decodeKeyStartOffset + STATE_ENCODING_NUM_VERSION_BYTES, - keyBytes.length - STATE_ENCODING_NUM_VERSION_BYTES - VIRTUAL_COL_FAMILY_PREFIX_BYTES) - keyRow + if (usingAvroEncoding) { + val avroBytes = new Array[Byte]( + keyBytes.length - STATE_ENCODING_NUM_VERSION_BYTES - VIRTUAL_COL_FAMILY_PREFIX_BYTES) + System.arraycopy( + keyBytes, + decodeKeyStartOffset + STATE_ENCODING_NUM_VERSION_BYTES, + avroBytes, + 0, + avroBytes.length + ) + decodeFromAvroToUnsafeRow( + keyBytes, avroEnc.get.keyDeserializer, keyAvroType, keyProj) + } else { + keyRow.pointTo( + keyBytes, + decodeKeyStartOffset + STATE_ENCODING_NUM_VERSION_BYTES, + keyBytes.length - STATE_ENCODING_NUM_VERSION_BYTES - VIRTUAL_COL_FAMILY_PREFIX_BYTES) + keyRow + } } else { null } @@ -727,17 +1113,28 @@ class NoPrefixKeyStateEncoder( * This encoder supports RocksDB StringAppendOperator merge operator. Values encoded can be * merged in RocksDB using merge operation, and all merged values can be read using decodeValues * operation. 
+ * If the avroEnc is specified, we are using Avro encoding for this column family's values */ -class MultiValuedStateEncoder(valueSchema: StructType) +class MultiValuedStateEncoder( + valueSchema: StructType, + avroEnc: Option[AvroEncoder] = None) extends RocksDBValueStateEncoder with Logging { import RocksDBStateEncoder._ + private val usingAvroEncoding = avroEnc.isDefined // Reusable objects + private val out = new ByteArrayOutputStream private val valueRow = new UnsafeRow(valueSchema.size) + private lazy val valueAvroType = SchemaConverters.toAvroType(valueSchema) + private val valueProj = UnsafeProjection.create(valueSchema) override def encodeValue(row: UnsafeRow): Array[Byte] = { - val bytes = encodeUnsafeRow(row) + val bytes = if (usingAvroEncoding) { + encodeUnsafeRowToAvro(row, avroEnc.get.valueSerializer, valueAvroType, out) + } else { + encodeUnsafeRow(row) + } val numBytes = bytes.length val encodedBytes = new Array[Byte](java.lang.Integer.BYTES + bytes.length) @@ -756,7 +1153,12 @@ class MultiValuedStateEncoder(valueSchema: StructType) val encodedValue = new Array[Byte](numBytes) Platform.copyMemory(valueBytes, java.lang.Integer.BYTES + Platform.BYTE_ARRAY_OFFSET, encodedValue, Platform.BYTE_ARRAY_OFFSET, numBytes) - decodeToUnsafeRow(encodedValue, valueRow) + if (usingAvroEncoding) { + decodeFromAvroToUnsafeRow( + encodedValue, avroEnc.get.valueDeserializer, valueAvroType, valueProj) + } else { + decodeToUnsafeRow(encodedValue, valueRow) + } } } @@ -782,7 +1184,12 @@ class MultiValuedStateEncoder(valueSchema: StructType) pos += numBytes pos += 1 // eat the delimiter character - decodeToUnsafeRow(encodedValue, valueRow) + if (usingAvroEncoding) { + decodeFromAvroToUnsafeRow( + encodedValue, avroEnc.get.valueDeserializer, valueAvroType, valueProj) + } else { + decodeToUnsafeRow(encodedValue, valueRow) + } } } } @@ -802,16 +1209,29 @@ class MultiValuedStateEncoder(valueSchema: StructType) * The bytes of a UnsafeRow is written unmodified to starting from offset 1 * (offset 0 is the version byte of value 0). That is, if the unsafe row has N bytes, * then the generated array byte will be N+1 bytes. + * If the avroEnc is specified, we are using Avro encoding for this column family's values */ -class SingleValueStateEncoder(valueSchema: StructType) - extends RocksDBValueStateEncoder { +class SingleValueStateEncoder( + valueSchema: StructType, + avroEnc: Option[AvroEncoder] = None) + extends RocksDBValueStateEncoder with Logging { import RocksDBStateEncoder._ + private val usingAvroEncoding = avroEnc.isDefined // Reusable objects + private val out = new ByteArrayOutputStream private val valueRow = new UnsafeRow(valueSchema.size) + private lazy val valueAvroType = SchemaConverters.toAvroType(valueSchema) + private val valueProj = UnsafeProjection.create(valueSchema) - override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row) + override def encodeValue(row: UnsafeRow): Array[Byte] = { + if (usingAvroEncoding) { + encodeUnsafeRowToAvro(row, avroEnc.get.valueSerializer, valueAvroType, out) + } else { + encodeUnsafeRow(row) + } + } /** * Decode byte array for a value to a UnsafeRow. @@ -820,7 +1240,15 @@ class SingleValueStateEncoder(valueSchema: StructType) * the given byte array. 
*/ override def decodeValue(valueBytes: Array[Byte]): UnsafeRow = { - decodeToUnsafeRow(valueBytes, valueRow) + if (valueBytes == null) { + return null + } + if (usingAvroEncoding) { + decodeFromAvroToUnsafeRow( + valueBytes, avroEnc.get.valueDeserializer, valueAvroType, valueProj) + } else { + decodeToUnsafeRow(valueBytes, valueRow) + } } override def supportsMultipleValuesPerKey: Boolean = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 1fc6ab5910c6c..146c983be3170 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -71,13 +71,14 @@ private[sql] class RocksDBStateStoreProvider valueSchema: StructType, keyStateEncoderSpec: KeyStateEncoderSpec, useMultipleValuesPerKey: Boolean = false, - isInternal: Boolean = false): Unit = { + isInternal: Boolean = false, + avroEnc: Option[AvroEncoder]): Unit = { verifyColFamilyCreationOrDeletion("create_col_family", colFamilyName, isInternal) val newColFamilyId = rocksDB.createColFamilyIfAbsent(colFamilyName) keyValueEncoderMap.putIfAbsent(colFamilyName, (RocksDBStateEncoder.getKeyEncoder(keyStateEncoderSpec, useColumnFamilies, - Some(newColFamilyId)), RocksDBStateEncoder.getValueEncoder(valueSchema, - useMultipleValuesPerKey))) + Some(newColFamilyId), avroEnc), RocksDBStateEncoder.getValueEncoder(valueSchema, + useMultipleValuesPerKey, avroEnc))) } override def get(key: UnsafeRow, colFamilyName: String): UnsafeRow = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala index 721d72b6a0991..5bb511d5d5567 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala @@ -17,13 +17,16 @@ package org.apache.spark.sql.execution.streaming.state +import scala.jdk.CollectionConverters.IterableHasAsJava import scala.util.Try +import org.apache.avro.SchemaValidatorBuilder import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.internal.{Logging, LogKeys, MDC} +import org.apache.spark.sql.avro.{AvroDeserializer, AvroSerializer, SchemaConverters} import org.apache.spark.sql.catalyst.util.UnsafeRowUtils import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, StatefulOperatorStateInfo} import org.apache.spark.sql.execution.streaming.state.SchemaHelper.{SchemaReader, SchemaWriter} @@ -37,14 +40,26 @@ case class StateSchemaValidationResult( schemaPath: String ) +// Avro encoder that is used by the RocksDBStateStoreProvider and RocksDBStateEncoder +// in order to serialize from UnsafeRow to a byte array of Avro encoding. 
+case class AvroEncoder( + keySerializer: AvroSerializer, + keyDeserializer: AvroDeserializer, + valueSerializer: AvroSerializer, + valueDeserializer: AvroDeserializer, + suffixKeySerializer: Option[AvroSerializer] = None, + suffixKeyDeserializer: Option[AvroDeserializer] = None +) extends Serializable + // Used to represent the schema of a column family in the state store case class StateStoreColFamilySchema( colFamilyName: String, keySchema: StructType, valueSchema: StructType, keyStateEncoderSpec: Option[KeyStateEncoderSpec] = None, - userKeyEncoderSchema: Option[StructType] = None -) + userKeyEncoderSchema: Option[StructType] = None, + avroEnc: Option[AvroEncoder] = None +) extends Serializable class StateSchemaCompatibilityChecker( providerId: StateStoreProviderId, @@ -138,7 +153,8 @@ class StateSchemaCompatibilityChecker( private def check( oldSchema: StateStoreColFamilySchema, newSchema: StateStoreColFamilySchema, - ignoreValueSchema: Boolean) : Unit = { + ignoreValueSchema: Boolean, + usingAvro: Boolean) : Boolean = { val (storedKeySchema, storedValueSchema) = (oldSchema.keySchema, oldSchema.valueSchema) val (keySchema, valueSchema) = (newSchema.keySchema, newSchema.valueSchema) @@ -146,14 +162,27 @@ class StateSchemaCompatibilityChecker( if (storedKeySchema.equals(keySchema) && (ignoreValueSchema || storedValueSchema.equals(valueSchema))) { // schema is exactly same + false } else if (!schemasCompatible(storedKeySchema, keySchema)) { throw StateStoreErrors.stateStoreKeySchemaNotCompatible(storedKeySchema.toString, keySchema.toString) + } else if (!ignoreValueSchema && usingAvro) { + // By this point, we know that old value schema is not equal to new value schema + val oldAvroSchema = SchemaConverters.toAvroType(storedValueSchema) + val newAvroSchema = SchemaConverters.toAvroType(valueSchema) + val validator = new SchemaValidatorBuilder().canReadStrategy.validateAll() + // This will throw a SchemaValidation exception if the schema has evolved in an + // unacceptable way. + validator.validate(newAvroSchema, Iterable(oldAvroSchema).asJava) + // If no exception is thrown, then we know that the schema evolved in an + // acceptable way + true } else if (!ignoreValueSchema && !schemasCompatible(storedValueSchema, valueSchema)) { throw StateStoreErrors.stateStoreValueSchemaNotCompatible(storedValueSchema.toString, valueSchema.toString) } else { logInfo("Detected schema change which is compatible. 
Allowing to put rows.") + true } } @@ -167,7 +196,8 @@ class StateSchemaCompatibilityChecker( def validateAndMaybeEvolveStateSchema( newStateSchema: List[StateStoreColFamilySchema], ignoreValueSchema: Boolean, - stateSchemaVersion: Int): Boolean = { + stateSchemaVersion: Int, + usingAvro: Boolean): Boolean = { val existingStateSchemaList = getExistingKeyAndValueSchema() val newStateSchemaList = newStateSchema @@ -182,18 +212,18 @@ class StateSchemaCompatibilityChecker( }.toMap // For each new state variable, we want to compare it to the old state variable // schema with the same name - newStateSchemaList.foreach { newSchema => - existingSchemaMap.get(newSchema.colFamilyName).foreach { existingStateSchema => - check(existingStateSchema, newSchema, ignoreValueSchema) - } + val hasEvolvedSchema = newStateSchemaList.exists { newSchema => + existingSchemaMap.get(newSchema.colFamilyName) + .exists(existingSchema => check(existingSchema, newSchema, ignoreValueSchema, usingAvro)) } val colFamiliesAddedOrRemoved = (newStateSchemaList.map(_.colFamilyName).toSet != existingSchemaMap.keySet) - if (stateSchemaVersion == SCHEMA_FORMAT_V3 && colFamiliesAddedOrRemoved) { + val newSchemaFileWritten = hasEvolvedSchema || colFamiliesAddedOrRemoved + if (stateSchemaVersion == SCHEMA_FORMAT_V3 && newSchemaFileWritten) { createSchemaFile(newStateSchemaList.sortBy(_.colFamilyName), stateSchemaVersion) } // TODO: [SPARK-49535] Write Schema files after schema has changed for StateSchemaV3 - colFamiliesAddedOrRemoved + newSchemaFileWritten } } @@ -242,7 +272,8 @@ object StateSchemaCompatibilityChecker { extraOptions: Map[String, String] = Map.empty, storeName: String = StateStoreId.DEFAULT_STORE_NAME, oldSchemaFilePath: Option[Path] = None, - newSchemaFilePath: Option[Path] = None): StateSchemaValidationResult = { + newSchemaFilePath: Option[Path] = None, + usingAvro: Boolean = false): StateSchemaValidationResult = { // SPARK-47776: collation introduces the concept of binary (in)equality, which means // in some collation we no longer be able to just compare the binary format of two // UnsafeRows to determine equality. For example, 'aaa' and 'AAA' can be "semantically" @@ -273,7 +304,7 @@ object StateSchemaCompatibilityChecker { val result = Try( checker.validateAndMaybeEvolveStateSchema(newStateSchema, ignoreValueSchema = !storeConf.formatValidationCheckValue, - stateSchemaVersion = stateSchemaVersion) + stateSchemaVersion = stateSchemaVersion, usingAvro) ).toEither.fold(Some(_), hasEvolvedSchema => { evolvedSchema = hasEvolvedSchema diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 72bc3ca33054d..50843b1aeb438 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -41,6 +41,13 @@ import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo import org.apache.spark.sql.types.StructType import org.apache.spark.util.{NextIterator, ThreadUtils, Utils} +sealed trait StateStoreEncoding + +object StateStoreEncoding { + case object UnsafeRow extends StateStoreEncoding + case object Avro extends StateStoreEncoding +} + /** * Base trait for a versioned key-value store which provides read operations. 
Each instance of a * `ReadStateStore` represents a specific version of state data, and such instances are created @@ -126,7 +133,8 @@ trait StateStore extends ReadStateStore { /** * Create column family with given name, if absent. - * + * If Avro encoding is enabled for this query, we expect the avroEncoderSpec to + * be defined so that the Key and Value StateEncoders will use this. * @return column family ID */ def createColFamilyIfAbsent( @@ -135,7 +143,8 @@ trait StateStore extends ReadStateStore { valueSchema: StructType, keyStateEncoderSpec: KeyStateEncoderSpec, useMultipleValuesPerKey: Boolean = false, - isInternal: Boolean = false): Unit + isInternal: Boolean = false, + avroEncoderSpec: Option[AvroEncoder] = None): Unit /** * Put a new non-null value for a non-null key. Implementations must be aware that the UnsafeRows diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 8f800b9f0252c..02048ee7ce682 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -111,6 +111,10 @@ trait StatefulOperator extends SparkPlan { } } + lazy val stateStoreEncoding: String = + conf.getConf( + SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT) + def metadataFilePath(): Path = { val stateCheckpointPath = new Path(getStateInfo.checkpointLocation, getStateInfo.operatorId.toString) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala index 9a04a0c759ac4..9a982a2264701 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala @@ -36,7 +36,8 @@ class MemoryStateStore extends StateStore() { valueSchema: StructType, keyStateEncoderSpec: KeyStateEncoderSpec, useMultipleValuesPerKey: Boolean = false, - isInternal: Boolean = false): Unit = { + isInternal: Boolean = false, + avroEnc: Option[AvroEncoder]): Unit = { throw StateStoreErrors.multipleColumnFamiliesNotSupported("MemoryStateStoreProvider") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala index 9ac74eb5b9e8f..346bfd37798f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala @@ -91,7 +91,8 @@ case class CkptIdCollectingStateStoreWrapper(innerStore: StateStore) extends Sta valueSchema: StructType, keyStateEncoderSpec: KeyStateEncoderSpec, useMultipleValuesPerKey: Boolean = false, - isInternal: Boolean = false): Unit = { + isInternal: Boolean = false, + avroEnc: Option[AvroEncoder]): Unit = { innerStore.createColFamilyIfAbsent( colFamilyName, keySchema, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index 
e1bd9dd38066b..c7a895d165037 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.LocalSparkSession.withSparkSession import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo +import org.apache.spark.sql.execution.streaming.{StatefulOperatorStateInfo, StateStoreColumnFamilySchemaUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ @@ -339,6 +339,82 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } + test("rocksdb range scan - fixed size non-ordering columns with Avro encoding") { + + + val keySchemaWithLong: StructType = StructType( + Seq(StructField("key1", StringType, false), StructField("key2", LongType, false), + StructField("key3", StringType, false), StructField("key4", LongType, false))) + + val remainingKeySchema: StructType = StructType( + Seq(StructField("key1", StringType, false), StructField("key3", StringType, false))) + tryWithProviderResource(newStoreProvider(keySchemaWithLong, + RangeKeyScanStateEncoderSpec(keySchemaWithLong, Seq(1, 3)), + useColumnFamilies = true)) { provider => + val store = provider.getStore(0) + + // use non-default col family if column families are enabled + val cfName = "testColFamily" + val convertedKeySchema = StateStoreColumnFamilySchemaUtils.convertForRangeScan( + keySchemaWithLong) + val avroSerde = StateStoreColumnFamilySchemaUtils(true).getAvroSerde( + keySchemaWithLong, + valueSchema, + Some(remainingKeySchema) + ) + store.createColFamilyIfAbsent(cfName, + keySchemaWithLong, valueSchema, + RangeKeyScanStateEncoderSpec(keySchemaWithLong, Seq(1, 3)), + avroEncoderSpec = avroSerde) + + val timerTimestamps = Seq(931L, 8000L, 452300L, 4200L, -1L, 90L, 1L, 2L, 8L, + -230L, -14569L, -92L, -7434253L, 35L, 6L, 9L, -323L, 5L) + val otherLongs = Seq(3L, 2L, 1L) + + // Create all combinations using flatMap + val testPairs = timerTimestamps.flatMap { ts1 => + timerTimestamps.map { ts2 => + (ts1, ts2) + } + } + + testPairs.foreach { ts => + // non-timestamp col is of fixed size + val keyRow = dataToKeyRowWithRangeScan("a", ts._1, ts._2) + val valueRow = dataToValueRow(1) + store.put(keyRow, valueRow, cfName) + assert(valueRowToData(store.get(keyRow, cfName)) === 1) + } + + val result = store.iterator(cfName).map { kv => + (kv.key.getLong(1), kv.key.getLong(3)) + }.toSeq + assert(result === testPairs.sortBy(pair => (pair._1, pair._2))) + store.commit() + + // test with a different set of power of 2 timestamps + val store1 = provider.getStore(1) + val timerTimestamps1 = Seq(-32L, -64L, -256L, 64L, 32L, 1024L, 4096L, 0L) + val testPairs1 = timerTimestamps1.flatMap { ts1 => + otherLongs.map { ts2 => + (ts1, ts2) + } + } + testPairs1.foreach { ts => + // non-timestamp col is of fixed size + val keyRow = dataToKeyRowWithRangeScan("a", ts._1, ts._2) + val valueRow = dataToValueRow(1) + store1.put(keyRow, valueRow, cfName) + assert(valueRowToData(store1.get(keyRow, cfName)) === 1) + } + + val result1 = store1.iterator(cfName).map { kv => + (kv.key.getLong(1), kv.key.getLong(3)) + }.toSeq + assert(result1 === (testPairs ++ 
testPairs1).sortBy(pair => (pair._1, pair._2))) + } + } + testWithColumnFamilies("rocksdb range scan - variable size non-ordering columns with " + "double type values are supported", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala index 637eb49130305..e8b884dd2d1c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala @@ -128,6 +128,23 @@ trait AlsoTestWithChangelogCheckpointingEnabled } } + def testWithEncodingTypes(testName: String, testTags: Tag*) + (testBody: => Any): Unit = { + Seq("UnsafeRow", "Avro").foreach { encoding => + super.test(testName + s" (encoding = $encoding)", testTags: _*) { + // in case tests have any code that needs to execute before every test + super.beforeEach() + withSQLConf( + SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key -> + encoding) { + testBody + } + // in case tests have any code that needs to execute after every test + super.afterEach() + } + } + } + def testWithColumnFamilies( testName: String, testMode: TestMode, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 47dd77f1bb9fd..90d8b157d94fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -1847,6 +1847,12 @@ object StateStoreTestsHelper { rangeScanProj.apply(new GenericInternalRow(Array[Any](ts, UTF8String.fromString(s)))).copy() } + def dataToKeyRowWithRangeScan(s: String, ts: Long, otherLong: Long): UnsafeRow = { + UnsafeProjection.create(Array[DataType](StringType, LongType, StringType, LongType)) + .apply(new GenericInternalRow(Array[Any](UTF8String.fromString(s), ts, + UTF8String.fromString(s), otherLong))).copy() + } + def dataToValueRow(i: Int): UnsafeRow = { valueProj.apply(new GenericInternalRow(Array[Any](i))).copy() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala index 88862e2ad0791..9f26de126e0dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala @@ -130,7 +130,7 @@ class TransformWithListStateSuite extends StreamTest with AlsoTestWithChangelogCheckpointingEnabled { import testImplicits._ - test("test appending null value in list state throw exception") { + testWithEncodingTypes("test appending null value in list state throw exception") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -150,7 +150,7 @@ class TransformWithListStateSuite extends StreamTest } } - test("test putting null value in list state throw exception") { + testWithEncodingTypes("test putting null value in list state throw exception") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -170,7 +170,7 @@ class TransformWithListStateSuite extends StreamTest 
} } - test("test putting null list in list state throw exception") { + testWithEncodingTypes("test putting null list in list state throw exception") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -190,7 +190,7 @@ class TransformWithListStateSuite extends StreamTest } } - test("test appending null list in list state throw exception") { + testWithEncodingTypes("test appending null list in list state throw exception") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -210,7 +210,7 @@ class TransformWithListStateSuite extends StreamTest } } - test("test putting empty list in list state throw exception") { + testWithEncodingTypes("test putting empty list in list state throw exception") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -230,7 +230,7 @@ class TransformWithListStateSuite extends StreamTest } } - test("test appending empty list in list state throw exception") { + testWithEncodingTypes("test appending empty list in list state throw exception") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -250,7 +250,7 @@ class TransformWithListStateSuite extends StreamTest } } - test("test list state correctness") { + testWithEncodingTypes("test list state correctness") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -307,7 +307,7 @@ class TransformWithListStateSuite extends StreamTest } } - test("test ValueState And ListState in Processor") { + testWithEncodingTypes("test ValueState And ListState in Processor") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala index 409a255ae3e64..ebd29bff5d354 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala @@ -105,7 +105,7 @@ class TransformWithListStateTTLSuite extends TransformWithStateTTLTest { override def getStateTTLMetricName: String = "numListStateWithTTLVars" - test("verify iterator works with expired values in beginning of list") { + testWithEncodingTypes("verify iterator works with expired values in beginning of list") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> "1") { @@ -195,7 +195,7 @@ class TransformWithListStateTTLSuite extends TransformWithStateTTLTest { // ascending order of TTL by stopping the query, setting the new TTL, and restarting // the query to check that the expired elements in the middle or end of the list // are not returned. 
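The tests renamed below (and in the other TransformWithState suites touched by this patch) go through the testWithEncodingTypes helper added to AlsoTestWithChangelogCheckpointingEnabled, which re-runs each body once with spark.sql.streaming.stateStore.encodingFormat set to UnsafeRow and once with Avro. A hedged usage sketch follows; the suite name is hypothetical and it assumes the RunningCountStatefulProcessor already defined in TransformWithStateSuite.

    // Imports mirror the existing TransformWithState suites
    // (StreamTest, MemoryStream, SQLConf, RocksDBStateStoreProvider, etc.).
    class EncodingFormatSmokeSuite extends StreamTest
      with AlsoTestWithChangelogCheckpointingEnabled {
      import testImplicits._

      testWithEncodingTypes("running count works under both state store encodings") {
        withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
            classOf[RocksDBStateStoreProvider].getName) {
          val inputData = MemoryStream[String]
          val result = inputData.toDS()
            .groupByKey(x => x)
            .transformWithState(new RunningCountStatefulProcessor(),
              TimeMode.None(), OutputMode.Update())

          // The encoding format conf is set by the helper around this body,
          // so the same assertions run against both encodings.
          testStream(result, OutputMode.Update())(
            AddData(inputData, "a"),
            CheckNewAnswer(("a", "1")),
            StopStream
          )
        }
      }
    }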
- test("verify iterator works with expired values in middle of list") { + testWithEncodingTypes("verify iterator works with expired values in middle of list") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> "1") { @@ -343,7 +343,7 @@ class TransformWithListStateTTLSuite extends TransformWithStateTTLTest { } } - test("verify iterator works with expired values in end of list") { + testWithEncodingTypes("verify iterator works with expired values in end of list") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> "1") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala index 76c5cbeee424b..63609ba96625c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala @@ -110,7 +110,7 @@ class TransformWithMapStateSuite extends StreamTest } } - test("Test retrieving value with non-existing user key") { + testWithEncodingTypes("Test retrieving value with non-existing user key") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -134,7 +134,7 @@ class TransformWithMapStateSuite extends StreamTest } } - test("Test put value with null value") { + testWithEncodingTypes("Test put value with null value") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -158,7 +158,7 @@ class TransformWithMapStateSuite extends StreamTest } } - test("Test map state correctness") { + testWithEncodingTypes("Test map state correctness") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { val inputData = MemoryStream[InputMapRow] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala index 022280eb3bcef..a68632534c001 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala @@ -182,7 +182,7 @@ class TransformWithMapStateTTLSuite extends TransformWithStateTTLTest { override def getStateTTLMetricName: String = "numMapStateWithTTLVars" - test("validate state is evicted with multiple user keys") { + testWithEncodingTypes("validate state is evicted with multiple user keys") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> "1") { @@ -224,7 +224,7 @@ class TransformWithMapStateTTLSuite extends TransformWithStateTTLTest { } } - test("verify iterator doesn't return expired keys") { + testWithEncodingTypes("verify iterator doesn't return expired keys") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> "1") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 3ef5c57ee3d07..457cf0d3c0cd2 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -409,7 +409,8 @@ class TransformWithStateSuite extends StateStoreMetricsTest import testImplicits._ - test("transformWithState - streaming with rocksdb and invalid processor should fail") { + testWithEncodingTypes("transformWithState - streaming with rocksdb " + + "and invalid processor should fail") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -430,10 +431,9 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - lazy iterators can properly get/set keyed state") { + testWithEncodingTypes("transformWithState - lazy iterators can properly get/set keyed state") { val spark = this.spark import spark.implicits._ - class ProcessorWithLazyIterators extends StatefulProcessor[Long, Long, Long] { @transient protected var _myValueState: ValueState[Long] = _ @@ -508,7 +508,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - streaming with rocksdb should succeed") { + testWithEncodingTypes("transformWithState - streaming with rocksdb should succeed") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -546,7 +546,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - streaming with rocksdb and processing time timer " + + testWithEncodingTypes("transformWithState - streaming with rocksdb and processing time timer " + "should succeed") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -591,7 +591,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - streaming with rocksdb and processing time timer " + + testWithEncodingTypes("transformWithState - streaming with rocksdb and processing time timer " + "and updating timers should succeed") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -627,7 +627,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - streaming with rocksdb and processing time timer " + + testWithEncodingTypes("transformWithState - streaming with rocksdb and processing time timer " + "and multiple timers should succeed") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -664,7 +664,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - streaming with rocksdb and event time based timer") { + testWithEncodingTypes("transformWithState - streaming with rocksdb and event time based timer") { val inputData = MemoryStream[(String, Int)] val result = inputData.toDS() @@ -708,7 +708,8 @@ class TransformWithStateSuite extends StateStoreMetricsTest ) } - test("Use statefulProcessor without transformWithState - handle should be absent") { + testWithEncodingTypes("Use statefulProcessor without transformWithState - " + + "handle should be absent") { val processor = new RunningCountStatefulProcessor() val ex = intercept[Exception] { processor.getHandle @@ -720,7 +721,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest ) } - test("transformWithState - batch should succeed") { + testWithEncodingTypes("transformWithState - 
batch should succeed") { val inputData = Seq("a", "b") val result = inputData.toDS() .groupByKey(x => x) @@ -732,7 +733,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest checkAnswer(df, Seq(("a", "1"), ("b", "1")).toDF()) } - test("transformWithState - test deleteIfExists operator") { + testWithEncodingTypes("transformWithState - test deleteIfExists operator") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -773,7 +774,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - two input streams") { + testWithEncodingTypes("transformWithState - two input streams") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -803,7 +804,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - three input streams") { + testWithEncodingTypes("transformWithState - three input streams") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -838,7 +839,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - two input streams, different key type") { + testWithEncodingTypes("transformWithState - two input streams, different key type") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -885,7 +886,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest OutputMode.Update()) } - test("transformWithState - availableNow trigger mode, rate limit is respected") { + testWithEncodingTypes("transformWithState - availableNow trigger mode, rate limit is respected") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { withTempDir { srcDir => @@ -926,7 +927,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - availableNow trigger mode, multiple restarts") { + testWithEncodingTypes("transformWithState - availableNow trigger mode, multiple restarts") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { withTempDir { srcDir => @@ -964,7 +965,8 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - verify StateSchemaV3 writes correct SQL schema of key/value") { + testWithEncodingTypes("transformWithState - verify StateSchemaV3 writes " + + "correct SQL schema of key/value") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -1046,7 +1048,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - verify that OperatorStateMetadataV2" + + testWithEncodingTypes("transformWithState - verify that OperatorStateMetadataV2" + " file is being written correctly") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, @@ -1090,7 +1092,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("test that invalid schema evolution fails query for column family") { + testWithEncodingTypes("test that invalid schema evolution fails query for column family") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -1170,7 
+1172,8 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("test that changing between different state variable types fails") { + testWithEncodingTypes("test that changing between different state " + + "variable types fails") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -1359,7 +1362,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("test query restart succeeds") { + testWithEncodingTypes("test query restart succeeds") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -1444,7 +1447,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest new Path(stateCheckpointPath, "_stateSchema/default/") } - test("transformWithState - verify that metadata and schema logs are purged") { + testWithEncodingTypes("transformWithState - verify that metadata and schema logs are purged") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala index 2ddf69aa49e04..e3b0a6b811742 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala @@ -21,7 +21,7 @@ import java.sql.Timestamp import java.time.Duration import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider +import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock @@ -41,14 +41,14 @@ case class OutputEvent( * Test suite base for TransformWithState with TTL support. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala
index 2ddf69aa49e04..e3b0a6b811742 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala
@@ -21,7 +21,7 @@ import java.sql.Timestamp
 import java.time.Duration
 
 import org.apache.spark.sql.execution.streaming.MemoryStream
-import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
+import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.util.StreamManualClock
@@ -41,14 +41,14 @@ case class OutputEvent(
  * Test suite base for TransformWithState with TTL support.
  */
 abstract class TransformWithStateTTLTest
-  extends StreamTest {
+  extends StreamTest with AlsoTestWithChangelogCheckpointingEnabled {
   import testImplicits._
 
   def getProcessor(ttlConfig: TTLConfig): StatefulProcessor[String, InputEvent, OutputEvent]
 
   def getStateTTLMetricName: String
 
-  test("validate state is evicted at ttl expiry") {
+  testWithEncodingTypes("validate state is evicted at ttl expiry") {
     withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
       classOf[RocksDBStateStoreProvider].getName) {
       withTempDir { dir =>
@@ -125,7 +125,7 @@ abstract class TransformWithStateTTLTest
     }
   }
 
-  test("validate state update updates the expiration timestamp") {
+  testWithEncodingTypes("validate state update updates the expiration timestamp") {
     withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
       classOf[RocksDBStateStoreProvider].getName) {
       val inputStream = MemoryStream[InputEvent]
@@ -187,7 +187,7 @@ abstract class TransformWithStateTTLTest
     }
   }
 
-  test("validate state is evicted at ttl expiry for no data batch") {
+  testWithEncodingTypes("validate state is evicted at ttl expiry for no data batch") {
     withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
       classOf[RocksDBStateStoreProvider].getName) {
       val inputStream = MemoryStream[InputEvent]
@@ -238,7 +238,7 @@ abstract class TransformWithStateTTLTest
     }
   }
 
-  test("validate only expired keys are removed from the state") {
+  testWithEncodingTypes("validate only expired keys are removed from the state") {
     withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
       classOf[RocksDBStateStoreProvider].getName,
       SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
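Note: in the TransformWithValueStateTTLSuite hunk that follows, the expected number of column family schemas grows from 5 to 9. Judging only by the schemas asserted there, the four user-facing state variables keep their primary column families, each TTL-enabled variable also gets a range-scannable expiry index (the "$ttl_"-prefixed families keyed by expiration and grouping key), and processing-time timers are tracked through a key-to-timestamp plus timestamp-to-key pair (the "$procTimers_"-prefixed families). A rough tally using only the names from the hunk (the primary/index grouping below is descriptive, not an API):

  // Column family names as asserted in the hunk below; 4 primary + 5 index = 9.
  val primary = Seq("valueState", "valueStateTTL", "listState", "mapState")
  val ttlIndexes = Seq("$ttl_valueStateTTL", "$ttl_listState", "$ttl_mapState")
  val timerIndexes = Seq("$procTimers_keyToTimestamp", "$procTimers_timestampToKey")
  assert((primary ++ ttlIndexes ++ timerIndexes).length == 9)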
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
index 21c3beb79314c..91e88eec06cd1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
@@ -23,7 +23,7 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Encoders
-import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, ListStateImplWithTTL, MapStateImplWithTTL, MemoryStream, TimerStateUtils, ValueStateImpl, ValueStateImplWithTTL}
+import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, ListStateImplWithTTL, MapStateImplWithTTL, MemoryStream, ValueStateImpl, ValueStateImplWithTTL}
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.util.StreamManualClock
@@ -195,7 +195,7 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
 
   override def getStateTTLMetricName: String = "numValueStateWithTTLVars"
 
-  test("validate multiple value states") {
+  testWithEncodingTypes("validate multiple value states") {
     withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
       classOf[RocksDBStateStoreProvider].getName) {
       val ttlKey = "k1"
@@ -262,7 +262,8 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
     }
   }
 
-  test("verify StateSchemaV3 writes correct SQL schema of key/value and with TTL") {
+  testWithEncodingTypes(
+    "verify StateSchemaV3 writes correct SQL schema of key/value and with TTL") {
     withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
       classOf[RocksDBStateStoreProvider].getName,
       SQLConf.SHUFFLE_PARTITIONS.key ->
@@ -275,60 +276,98 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
       val fm = CheckpointFileManager.create(stateSchemaPath, hadoopConf)
       val keySchema = new StructType().add("value", StringType)
-      val schemaForKeyRow: StructType = new StructType()
-        .add("key", new StructType(keySchema.fields))
+      val timerKeyStruct = new StructType(keySchema.fields)
+      val schemaForTimerKeyRow: StructType = new StructType()
+        .add("key", timerKeyStruct)
         .add("expiryTimestampMs", LongType, nullable = false)
-      val schemaForValueRow: StructType = StructType(Array(StructField("__dummy__", NullType)))
+      val schemaForTimerValueRow: StructType =
+        StructType(Array(StructField("__dummy__", NullType)))
+
+      // Timer schemas
       val schema0 = StateStoreColFamilySchema(
-        TimerStateUtils.getTimerStateVarName(TimeMode.ProcessingTime().toString),
-        schemaForKeyRow,
-        schemaForValueRow,
-        Some(PrefixKeyScanStateEncoderSpec(schemaForKeyRow, 1)))
+        "$procTimers_keyToTimestamp",
+        schemaForTimerKeyRow,
+        schemaForTimerValueRow,
+        Some(PrefixKeyScanStateEncoderSpec(schemaForTimerKeyRow, 1)))
+
+      val schemaForTimerReverseKeyRow: StructType = new StructType()
+        .add("expiryTimestampMs", LongType, nullable = false)
+        .add("key", timerKeyStruct)
       val schema1 = StateStoreColFamilySchema(
-        "valueStateTTL",
-        keySchema,
-        new StructType().add("value",
-          new StructType()
-            .add("value", IntegerType, false))
-          .add("ttlExpirationMs", LongType),
-        Some(NoPrefixKeyStateEncoderSpec(keySchema)),
-        None
-      )
+        "$procTimers_timestampToKey",
+        schemaForTimerReverseKeyRow,
+        schemaForTimerValueRow,
+        Some(RangeKeyScanStateEncoderSpec(schemaForTimerReverseKeyRow, List(0))))
+
+      // TTL tracking schemas
+      val ttlKeySchema = new StructType()
+        .add("expirationMs", BinaryType)
+        .add("groupingKey", keySchema)
+
       val schema2 = StateStoreColFamilySchema(
-        "valueState",
-        keySchema,
-        new StructType().add("value", IntegerType, false),
-        Some(NoPrefixKeyStateEncoderSpec(keySchema)),
-        None
-      )
+        "$ttl_listState",
+        ttlKeySchema,
+        schemaForTimerValueRow,
+        Some(RangeKeyScanStateEncoderSpec(ttlKeySchema, List(0))))
+
+      val userKeySchema = new StructType()
+        .add("id", IntegerType, false)
+        .add("name", StringType)
+      val ttlMapKeySchema = new StructType()
+        .add("expirationMs", BinaryType)
+        .add("groupingKey", keySchema)
+        .add("userKey", userKeySchema)
+
       val schema3 = StateStoreColFamilySchema(
+        "$ttl_mapState",
+        ttlMapKeySchema,
+        schemaForTimerValueRow,
+        Some(RangeKeyScanStateEncoderSpec(ttlMapKeySchema, List(0))))
+
+      val schema4 = StateStoreColFamilySchema(
+        "$ttl_valueStateTTL",
+        ttlKeySchema,
+        schemaForTimerValueRow,
+        Some(RangeKeyScanStateEncoderSpec(ttlKeySchema, List(0))))
+
+      // Main state schemas
+      val schema5 = StateStoreColFamilySchema(
         "listState",
         keySchema,
-        new StructType().add("value",
-          new StructType()
+        new StructType()
+          .add("value", new StructType()
             .add("id", LongType, false)
             .add("name", StringType))
           .add("ttlExpirationMs", LongType),
-        Some(NoPrefixKeyStateEncoderSpec(keySchema)),
-        None
-      )
+        Some(NoPrefixKeyStateEncoderSpec(keySchema)))
 
-      val userKeySchema = new StructType()
-        .add("id", IntegerType, false)
-        .add("name", StringType)
       val compositeKeySchema = new StructType()
         .add("key", new StructType().add("value", StringType))
         .add("userKey", userKeySchema)
-      val schema4 = StateStoreColFamilySchema(
+      val schema6 = StateStoreColFamilySchema(
         "mapState",
         compositeKeySchema,
-        new StructType().add("value",
-          new StructType()
+        new StructType()
+          .add("value", new StructType()
            .add("value", StringType))
          .add("ttlExpirationMs", LongType),
        Some(PrefixKeyScanStateEncoderSpec(compositeKeySchema, 1)),
-        Option(userKeySchema)
-      )
+        Option(userKeySchema))
+
+      val schema7 = StateStoreColFamilySchema(
+        "valueState",
+        keySchema,
+        new StructType().add("value", IntegerType, false),
+        Some(NoPrefixKeyStateEncoderSpec(keySchema)))
+
+      val schema8 = StateStoreColFamilySchema(
+        "valueStateTTL",
+        keySchema,
+        new StructType()
+          .add("value", new StructType()
+            .add("value", IntegerType, false))
+          .add("ttlExpirationMs", LongType),
+        Some(NoPrefixKeyStateEncoderSpec(keySchema)))
 
       val ttlKey = "k1"
       val noTtlKey = "k2"
@@ -370,9 +409,11 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
             q.lastProgress.stateOperators.head.customMetrics
               .get("numMapStateWithTTLVars").toInt)
 
-          assert(colFamilySeq.length == 5)
+          // Now expect 9 column families
+          assert(colFamilySeq.length == 9)
           assert(colFamilySeq.map(_.toString).toSet == Set(
-            schema0, schema1, schema2, schema3, schema4
+            schema0, schema1, schema2, schema3, schema4,
+            schema5, schema6, schema7, schema8
           ).map(_.toString))
         },
         StopStream