diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 123759c6c8b80..f2c1d11a226fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2230,6 +2230,17 @@ object SQLConf { .intConf .createWithDefault(1) + val STREAMING_STATE_STORE_ENCODING_FORMAT = + buildConf("spark.sql.streaming.stateStore.encodingFormat") + .doc("The encoding format used for stateful operators to store information " + + "in the state store") + .version("4.0.0") + .stringConf + .transform(_.toLowerCase(Locale.ROOT)) + .checkValue(v => Set("unsaferow", "avro").contains(v), + "Valid values are 'unsaferow' and 'avro'") + .createWithDefault("unsaferow") + val STATE_STORE_COMPRESSION_CODEC = buildConf("spark.sql.streaming.stateStore.compression.codec") .internal() @@ -5607,6 +5618,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def stateStoreCheckpointFormatVersion: Int = getConf(STATE_STORE_CHECKPOINT_FORMAT_VERSION) + def stateStoreEncodingFormat: String = getConf(STREAMING_STATE_STORE_ENCODING_FORMAT) + def checkpointRenamedFileCheck: Boolean = getConf(CHECKPOINT_RENAMEDFILE_CHECK_ENABLED) def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala index 7da8408f98b0f..585298fa4c993 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala @@ -20,10 +20,49 @@ import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._ import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, StateStoreColFamilySchema} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ object StateStoreColumnFamilySchemaUtils { + /** + * Avro uses zig-zag encoding for some fixed-length types, like Longs and Ints. For range scans + * we want to use big-endian encoding, so we need to convert the source schema to replace these + * types with BinaryType. + * + * @param schema The schema to convert + * @param ordinals If non-empty, only convert fields at these ordinals. + * If empty, convert all fields. + */ + def convertForRangeScan(schema: StructType, ordinals: Seq[Int] = Seq.empty): StructType = { + val ordinalSet = ordinals.toSet + + StructType(schema.fields.zipWithIndex.flatMap { case (field, idx) => + if ((ordinals.isEmpty || ordinalSet.contains(idx)) && isFixedSize(field.dataType)) { + // For each numeric field, create two fields: + // 1. Byte marker for null, positive, or negative values + // 2. 
The original numeric value in big-endian format + // Byte type is converted to Int in Avro, which doesn't work for us as Avro + // uses zig-zag encoding as opposed to big-endian for Ints + Seq( + StructField(s"${field.name}_marker", BinaryType, nullable = false), + field.copy(name = s"${field.name}_value", BinaryType) + ) + } else { + Seq(field) + } + }) + } + + private def isFixedSize(dataType: DataType): Boolean = dataType match { + case _: ByteType | _: BooleanType | _: ShortType | _: IntegerType | _: LongType | + _: FloatType | _: DoubleType => true + case _ => false + } + + def getTtlColFamilyName(stateName: String): String = { + "$ttl_" + stateName + } + def getValueStateSchema[T]( stateName: String, keyEncoder: ExpressionEncoder[Any], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index bd501c9357234..44202bb0d2944 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -715,6 +715,7 @@ abstract class StreamExecution( object StreamExecution { val QUERY_ID_KEY = "sql.streaming.queryId" + val RUN_ID_KEY = "sql.streaming.runId" val IS_CONTINUOUS_PROCESSING = "__is_continuous_processing" val IO_EXCEPTION_NAMES = Seq( classOf[InterruptedException].getName, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index 4c7a226e0973f..f39022c1f53a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -17,13 +17,21 @@ package org.apache.spark.sql.execution.streaming.state +import java.io.ByteArrayOutputStream import java.lang.Double.{doubleToRawLongBits, longBitsToDouble} import java.lang.Float.{floatToRawIntBits, intBitsToFloat} import java.nio.{ByteBuffer, ByteOrder} +import org.apache.avro.Schema +import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord} +import org.apache.avro.io.{DecoderFactory, EncoderFactory} + import org.apache.spark.internal.Logging +import org.apache.spark.sql.avro.{AvroDeserializer, AvroSerializer, SchemaConverters} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter +import org.apache.spark.sql.execution.streaming.StateStoreColumnFamilySchemaUtils import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{STATE_ENCODING_NUM_VERSION_BYTES, STATE_ENCODING_VERSION, VIRTUAL_COL_FAMILY_PREFIX_BYTES} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform @@ -49,6 +57,7 @@ abstract class RocksDBKeyStateEncoderBase( def offsetForColFamilyPrefix: Int = if (useColumnFamilies) VIRTUAL_COL_FAMILY_PREFIX_BYTES else 0 + val out = new ByteArrayOutputStream /** * Get Byte Array for the virtual column family id that is used as prefix for * key state rows. 
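As context for the `convertForRangeScan` change above: Avro's zig-zag varints do not sort in numeric order under RocksDB's unsigned lexicographic byte comparator, so range-scan columns are rewritten as a sign-marker byte followed by big-endian value bytes. The minimal, standalone Scala sketch below (not part of the patch; the marker constants are illustrative placeholders, not the encoder's actual values) demonstrates that this layout preserves numeric ordering under byte-wise comparison:

```scala
import java.nio.ByteBuffer

object RangeScanOrderingSketch {
  // Illustrative marker values; the real encoder defines its own null/negative/positive markers.
  private val NegativeMarker: Byte = 0x00
  private val PositiveMarker: Byte = 0x01

  // Encode a Long as [marker | 8 big-endian bytes]; ByteBuffer writes big-endian by default.
  def encodeLong(v: Long): Array[Byte] = {
    val buf = ByteBuffer.allocate(1 + java.lang.Long.BYTES)
    buf.put(if (v < 0) NegativeMarker else PositiveMarker)
    buf.putLong(v)
    buf.array()
  }

  def main(args: Array[String]): Unit = {
    val values = Seq(Long.MinValue, -42L, -1L, 0L, 1L, 42L, Long.MaxValue)
    val encoded = values.map(encodeLong)
    // RocksDB compares keys as unsigned lexicographic byte strings; the marker byte keeps all
    // negatives before all non-negatives, and big-endian two's complement orders each group.
    val ordered = encoded.sliding(2).forall { w =>
      java.util.Arrays.compareUnsigned(w(0), w(1)) < 0
    }
    println(s"byte order matches numeric order: $ordered") // prints: true
  }
}
```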
@@ -89,23 +98,24 @@ abstract class RocksDBKeyStateEncoderBase( } } -object RocksDBStateEncoder { +object RocksDBStateEncoder extends Logging { def getKeyEncoder( keyStateEncoderSpec: KeyStateEncoderSpec, useColumnFamilies: Boolean, - virtualColFamilyId: Option[Short] = None): RocksDBKeyStateEncoder = { + virtualColFamilyId: Option[Short] = None, + avroEnc: Option[AvroEncoder] = None): RocksDBKeyStateEncoder = { // Return the key state encoder based on the requested type keyStateEncoderSpec match { case NoPrefixKeyStateEncoderSpec(keySchema) => - new NoPrefixKeyStateEncoder(keySchema, useColumnFamilies, virtualColFamilyId) + new NoPrefixKeyStateEncoder(keySchema, useColumnFamilies, virtualColFamilyId, avroEnc) case PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) => new PrefixKeyScanStateEncoder(keySchema, numColsPrefixKey, - useColumnFamilies, virtualColFamilyId) + useColumnFamilies, virtualColFamilyId, avroEnc) case RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) => new RangeKeyScanStateEncoder(keySchema, orderingOrdinals, - useColumnFamilies, virtualColFamilyId) + useColumnFamilies, virtualColFamilyId, avroEnc) case _ => throw new IllegalArgumentException(s"Unsupported key state encoder spec: " + @@ -115,11 +125,12 @@ object RocksDBStateEncoder { def getValueEncoder( valueSchema: StructType, - useMultipleValuesPerKey: Boolean): RocksDBValueStateEncoder = { + useMultipleValuesPerKey: Boolean, + avroEnc: Option[AvroEncoder] = None): RocksDBValueStateEncoder = { if (useMultipleValuesPerKey) { - new MultiValuedStateEncoder(valueSchema) + new MultiValuedStateEncoder(valueSchema, avroEnc) } else { - new SingleValueStateEncoder(valueSchema) + new SingleValueStateEncoder(valueSchema, avroEnc) } } @@ -145,6 +156,26 @@ object RocksDBStateEncoder { encodedBytes } + /** + * This method takes an UnsafeRow, and serializes to a byte array using Avro encoding. 
+ */ + def encodeUnsafeRowToAvro( + row: UnsafeRow, + avroSerializer: AvroSerializer, + valueAvroType: Schema, + out: ByteArrayOutputStream): Array[Byte] = { + // InternalRow -> Avro.GenericDataRecord + val avroData = + avroSerializer.serialize(row) + out.reset() + val encoder = EncoderFactory.get().directBinaryEncoder(out, null) + val writer = new GenericDatumWriter[Any]( + valueAvroType) // Defining Avro writer for this struct type + writer.write(avroData, encoder) // Avro.GenericDataRecord -> byte array + encoder.flush() + out.toByteArray + } + def decodeToUnsafeRow(bytes: Array[Byte], numFields: Int): UnsafeRow = { if (bytes != null) { val row = new UnsafeRow(numFields) @@ -154,6 +185,26 @@ object RocksDBStateEncoder { } } + /** + * This method takes a byte array written using Avro encoding, and + * deserializes to an UnsafeRow using the Avro deserializer + */ + def decodeFromAvroToUnsafeRow( + valueBytes: Array[Byte], + avroDeserializer: AvroDeserializer, + valueAvroType: Schema, + valueProj: UnsafeProjection): UnsafeRow = { + val reader = new GenericDatumReader[Any](valueAvroType) + val decoder = DecoderFactory.get().binaryDecoder(valueBytes, 0, valueBytes.length, null) + // bytes -> Avro.GenericDataRecord + val genericData = reader.read(null, decoder) + // Avro.GenericDataRecord -> InternalRow + val internalRow = avroDeserializer.deserialize( + genericData).orNull.asInstanceOf[InternalRow] + // InternalRow -> UnsafeRow + valueProj.apply(internalRow) + } + def decodeToUnsafeRow(bytes: Array[Byte], reusedRow: UnsafeRow): UnsafeRow = { if (bytes != null) { // Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st offset. See Platform. @@ -174,16 +225,20 @@ object RocksDBStateEncoder { * @param keySchema - schema of the key to be encoded * @param numColsPrefixKey - number of columns to be used for prefix key * @param useColumnFamilies - if column family is enabled for this encoder + * @param avroEnc - if Avro encoding is specified for this StateEncoder, this encoder will + * be defined */ class PrefixKeyScanStateEncoder( keySchema: StructType, numColsPrefixKey: Int, useColumnFamilies: Boolean = false, - virtualColFamilyId: Option[Short] = None) + virtualColFamilyId: Option[Short] = None, + avroEnc: Option[AvroEncoder] = None) extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) { import RocksDBStateEncoder._ + private val usingAvroEncoding = avroEnc.isDefined private val prefixKeyFieldsWithIdx: Seq[(StructField, Int)] = { keySchema.zipWithIndex.take(numColsPrefixKey) } @@ -203,6 +258,18 @@ class PrefixKeyScanStateEncoder( UnsafeProjection.create(refs) } + // Prefix Key schema and projection definitions used by the Avro Serializers + // and Deserializers + private val prefixKeySchema = StructType(keySchema.take(numColsPrefixKey)) + private lazy val prefixKeyAvroType = SchemaConverters.toAvroType(prefixKeySchema) + private val prefixKeyProj = UnsafeProjection.create(prefixKeySchema) + + // Remaining Key schema and projection definitions used by the Avro Serializers + // and Deserializers + private val remainingKeySchema = StructType(keySchema.drop(numColsPrefixKey)) + private lazy val remainingKeyAvroType = SchemaConverters.toAvroType(remainingKeySchema) + private val remainingKeyProj = UnsafeProjection.create(remainingKeySchema) + // This is quite simple to do - just bind sequentially, as we don't change the order. 
private val restoreKeyProjection: UnsafeProjection = UnsafeProjection.create(keySchema) @@ -210,9 +277,24 @@ class PrefixKeyScanStateEncoder( private val joinedRowOnKey = new JoinedRow() override def encodeKey(row: UnsafeRow): Array[Byte] = { - val prefixKeyEncoded = encodeUnsafeRow(extractPrefixKey(row)) - val remainingEncoded = encodeUnsafeRow(remainingKeyProjection(row)) - + val (prefixKeyEncoded, remainingEncoded) = if (usingAvroEncoding) { + ( + encodeUnsafeRowToAvro( + extractPrefixKey(row), + avroEnc.get.keySerializer, + prefixKeyAvroType, + out + ), + encodeUnsafeRowToAvro( + remainingKeyProjection(row), + avroEnc.get.suffixKeySerializer.get, + remainingKeyAvroType, + out + ) + ) + } else { + (encodeUnsafeRow(extractPrefixKey(row)), encodeUnsafeRow(remainingKeyProjection(row))) + } val (encodedBytes, startingOffset) = encodeColumnFamilyPrefix( prefixKeyEncoded.length + remainingEncoded.length + 4 ) @@ -243,9 +325,25 @@ class PrefixKeyScanStateEncoder( Platform.copyMemory(keyBytes, decodeKeyStartOffset + 4 + prefixKeyEncodedLen, remainingKeyEncoded, Platform.BYTE_ARRAY_OFFSET, remainingKeyEncodedLen) - val prefixKeyDecoded = decodeToUnsafeRow(prefixKeyEncoded, numFields = numColsPrefixKey) - val remainingKeyDecoded = decodeToUnsafeRow(remainingKeyEncoded, - numFields = keySchema.length - numColsPrefixKey) + val (prefixKeyDecoded, remainingKeyDecoded) = if (usingAvroEncoding) { + ( + decodeFromAvroToUnsafeRow( + prefixKeyEncoded, + avroEnc.get.keyDeserializer, + prefixKeyAvroType, + prefixKeyProj + ), + decodeFromAvroToUnsafeRow( + remainingKeyEncoded, + avroEnc.get.suffixKeyDeserializer.get, + remainingKeyAvroType, + remainingKeyProj + ) + ) + } else { + (decodeToUnsafeRow(prefixKeyEncoded, numFields = numColsPrefixKey), + decodeToUnsafeRow(remainingKeyEncoded, numFields = keySchema.length - numColsPrefixKey)) + } restoreKeyProjection(joinedRowOnKey.withLeft(prefixKeyDecoded).withRight(remainingKeyDecoded)) } @@ -255,7 +353,11 @@ class PrefixKeyScanStateEncoder( } override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = { - val prefixKeyEncoded = encodeUnsafeRow(prefixKey) + val prefixKeyEncoded = if (usingAvroEncoding) { + encodeUnsafeRowToAvro(prefixKey, avroEnc.get.keySerializer, prefixKeyAvroType, out) + } else { + encodeUnsafeRow(prefixKey) + } val (prefix, startingOffset) = encodeColumnFamilyPrefix( prefixKeyEncoded.length + 4 ) @@ -299,13 +401,16 @@ class PrefixKeyScanStateEncoder( * @param keySchema - schema of the key to be encoded * @param orderingOrdinals - the ordinals for which the range scan is constructed * @param useColumnFamilies - if column family is enabled for this encoder + * @param avroEnc - if Avro encoding is specified for this StateEncoder, this encoder will + * be defined */ class RangeKeyScanStateEncoder( keySchema: StructType, orderingOrdinals: Seq[Int], useColumnFamilies: Boolean = false, - virtualColFamilyId: Option[Short] = None) - extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) { + virtualColFamilyId: Option[Short] = None, + avroEnc: Option[AvroEncoder] = None) + extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) with Logging { import RocksDBStateEncoder._ @@ -374,6 +479,22 @@ class RangeKeyScanStateEncoder( UnsafeProjection.create(refs) } + private val rangeScanAvroSchema = StateStoreColumnFamilySchemaUtils.convertForRangeScan( + StructType(rangeScanKeyFieldsWithOrdinal.map(_._1).toArray)) + + private lazy val rangeScanAvroType = SchemaConverters.toAvroType(rangeScanAvroSchema) + + private val 
rangeScanAvroProjection = UnsafeProjection.create(rangeScanAvroSchema) + + // Existing remainder key schema stuff + private val remainingKeySchema = StructType( + 0.to(keySchema.length - 1).diff(orderingOrdinals).map(keySchema(_)) + ) + + private lazy val remainingKeyAvroType = SchemaConverters.toAvroType(remainingKeySchema) + + private val remainingKeyAvroProjection = UnsafeProjection.create(remainingKeySchema) + // Reusable objects private val joinedRowOnKey = new JoinedRow() @@ -563,13 +684,272 @@ class RangeKeyScanStateEncoder( writer.getRow() } + /** + * Encodes an UnsafeRow into an Avro-compatible byte array format for range scan operations. + * + * This method transforms row data into a binary format that preserves ordering when + * used in range scans. + * For each field in the row: + * - A marker byte is written to indicate null status or sign (for numeric types) + * - The value is written in big-endian format + * + * Special handling is implemented for: + * - Null values: marked with nullValMarker followed by zero bytes + * - Negative numbers: marked with negativeValMarker + * - Floating point numbers: bit manipulation to handle sign and NaN values correctly + * + * @param row The UnsafeRow to encode + * @param avroType The Avro schema defining the structure for encoding + * @return Array[Byte] containing the Avro-encoded data that preserves ordering for range scans + * @throws UnsupportedOperationException if a field's data type is not supported for range + * scan encoding + */ + def encodePrefixKeyForRangeScan( + row: UnsafeRow, + avroType: Schema): Array[Byte] = { + val record = new GenericData.Record(avroType) + var fieldIdx = 0 + rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) => + val field = fieldWithOrdinal._1 + val value = row.get(idx, field.dataType) + + // Create marker byte buffer + val markerBuffer = ByteBuffer.allocate(1) + markerBuffer.order(ByteOrder.BIG_ENDIAN) + + if (value == null) { + markerBuffer.put(nullValMarker) + record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array())) + record.put(fieldIdx + 1, ByteBuffer.wrap(new Array[Byte](field.dataType.defaultSize))) + } else { + field.dataType match { + case BooleanType => + markerBuffer.put(positiveValMarker) + record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array())) + val valueBuffer = ByteBuffer.allocate(1) + valueBuffer.put(if (value.asInstanceOf[Boolean]) 1.toByte else 0.toByte) + record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array())) + + case ByteType => + val byteVal = value.asInstanceOf[Byte] + markerBuffer.put(if (byteVal < 0) negativeValMarker else positiveValMarker) + record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array())) + + val valueBuffer = ByteBuffer.allocate(1) + valueBuffer.order(ByteOrder.BIG_ENDIAN) + valueBuffer.put(byteVal) + record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array())) + + case ShortType => + val shortVal = value.asInstanceOf[Short] + markerBuffer.put(if (shortVal < 0) negativeValMarker else positiveValMarker) + record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array())) + + val valueBuffer = ByteBuffer.allocate(2) + valueBuffer.order(ByteOrder.BIG_ENDIAN) + valueBuffer.putShort(shortVal) + record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array())) + + case IntegerType => + val intVal = value.asInstanceOf[Int] + markerBuffer.put(if (intVal < 0) negativeValMarker else positiveValMarker) + record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array())) + + val valueBuffer = ByteBuffer.allocate(4) + 
valueBuffer.order(ByteOrder.BIG_ENDIAN) + valueBuffer.putInt(intVal) + record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array())) + + case LongType => + val longVal = value.asInstanceOf[Long] + markerBuffer.put(if (longVal < 0) negativeValMarker else positiveValMarker) + record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array())) + + val valueBuffer = ByteBuffer.allocate(8) + valueBuffer.order(ByteOrder.BIG_ENDIAN) + valueBuffer.putLong(longVal) + record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array())) + + case FloatType => + val floatVal = value.asInstanceOf[Float] + val rawBits = floatToRawIntBits(floatVal) + markerBuffer.put(if ((rawBits & floatSignBitMask) != 0) { + negativeValMarker + } else { + positiveValMarker + }) + record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array())) + + val valueBuffer = ByteBuffer.allocate(4) + valueBuffer.order(ByteOrder.BIG_ENDIAN) + if ((rawBits & floatSignBitMask) != 0) { + val updatedVal = rawBits ^ floatFlipBitMask + valueBuffer.putFloat(intBitsToFloat(updatedVal)) + } else { + valueBuffer.putFloat(floatVal) + } + record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array())) + + case DoubleType => + val doubleVal = value.asInstanceOf[Double] + val rawBits = doubleToRawLongBits(doubleVal) + markerBuffer.put(if ((rawBits & doubleSignBitMask) != 0) { + negativeValMarker + } else { + positiveValMarker + }) + record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array())) + + val valueBuffer = ByteBuffer.allocate(8) + valueBuffer.order(ByteOrder.BIG_ENDIAN) + if ((rawBits & doubleSignBitMask) != 0) { + val updatedVal = rawBits ^ doubleFlipBitMask + valueBuffer.putDouble(longBitsToDouble(updatedVal)) + } else { + valueBuffer.putDouble(doubleVal) + } + record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array())) + + case _ => throw new UnsupportedOperationException( + s"Range scan encoding not supported for data type: ${field.dataType}") + } + } + fieldIdx += 2 + } + + out.reset() + val writer = new GenericDatumWriter[GenericRecord](rangeScanAvroType) + val encoder = EncoderFactory.get().binaryEncoder(out, null) + writer.write(record, encoder) + encoder.flush() + out.toByteArray + } + + /** + * Decodes an Avro-encoded byte array back into an UnsafeRow for range scan operations. 
+ * + * This method reverses the encoding process performed by encodePrefixKeyForRangeScan: + * - Reads the marker byte to determine null status or sign + * - Reconstructs the original values from big-endian format + * - Handles special cases for floating point numbers by reversing bit manipulations + * + * The decoding process preserves the original data types and values, including: + * - Null values marked by nullValMarker + * - Sign information for numeric types + * - Proper restoration of negative floating point values + * + * @param bytes The Avro-encoded byte array to decode + * @param avroType The Avro schema defining the structure for decoding + * @return UnsafeRow containing the decoded data + * @throws UnsupportedOperationException if a field's data type is not supported for range + * scan decoding + */ + def decodePrefixKeyForRangeScan( + bytes: Array[Byte], + avroType: Schema): UnsafeRow = { + + val reader = new GenericDatumReader[GenericRecord](avroType) + val decoder = DecoderFactory.get().binaryDecoder(bytes, 0, bytes.length, null) + val record = reader.read(null, decoder) + + val rowWriter = new UnsafeRowWriter(rangeScanKeyFieldsWithOrdinal.length) + rowWriter.resetRowWriter() + + var fieldIdx = 0 + rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) => + val field = fieldWithOrdinal._1 + + val markerBytes = record.get(fieldIdx).asInstanceOf[ByteBuffer].array() + val markerBuf = ByteBuffer.wrap(markerBytes) + markerBuf.order(ByteOrder.BIG_ENDIAN) + val marker = markerBuf.get() + + if (marker == nullValMarker) { + rowWriter.setNullAt(idx) + } else { + field.dataType match { + case BooleanType => + val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array() + rowWriter.write(idx, bytes(0) == 1) + + case ByteType => + val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array() + val valueBuf = ByteBuffer.wrap(bytes) + valueBuf.order(ByteOrder.BIG_ENDIAN) + rowWriter.write(idx, valueBuf.get()) + + case ShortType => + val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array() + val valueBuf = ByteBuffer.wrap(bytes) + valueBuf.order(ByteOrder.BIG_ENDIAN) + rowWriter.write(idx, valueBuf.getShort()) + + case IntegerType => + val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array() + val valueBuf = ByteBuffer.wrap(bytes) + valueBuf.order(ByteOrder.BIG_ENDIAN) + rowWriter.write(idx, valueBuf.getInt()) + + case LongType => + val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array() + val valueBuf = ByteBuffer.wrap(bytes) + valueBuf.order(ByteOrder.BIG_ENDIAN) + rowWriter.write(idx, valueBuf.getLong()) + + case FloatType => + val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array() + val valueBuf = ByteBuffer.wrap(bytes) + valueBuf.order(ByteOrder.BIG_ENDIAN) + if (marker == negativeValMarker) { + val floatVal = valueBuf.getFloat + val updatedVal = floatToRawIntBits(floatVal) ^ floatFlipBitMask + rowWriter.write(idx, intBitsToFloat(updatedVal)) + } else { + rowWriter.write(idx, valueBuf.getFloat()) + } + + case DoubleType => + val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array() + val valueBuf = ByteBuffer.wrap(bytes) + valueBuf.order(ByteOrder.BIG_ENDIAN) + if (marker == negativeValMarker) { + val doubleVal = valueBuf.getDouble + val updatedVal = doubleToRawLongBits(doubleVal) ^ doubleFlipBitMask + rowWriter.write(idx, longBitsToDouble(updatedVal)) + } else { + rowWriter.write(idx, valueBuf.getDouble()) + } + + case _ => throw new UnsupportedOperationException( + s"Range 
scan decoding not supported for data type: ${field.dataType}") + } + } + fieldIdx += 2 + } + + rowWriter.getRow() + } + override def encodeKey(row: UnsafeRow): Array[Byte] = { // This prefix key has the columns specified by orderingOrdinals val prefixKey = extractPrefixKey(row) - val rangeScanKeyEncoded = encodeUnsafeRow(encodePrefixKeyForRangeScan(prefixKey)) + val rangeScanKeyEncoded = if (avroEnc.isDefined) { + encodePrefixKeyForRangeScan(prefixKey, rangeScanAvroType) + } else { + encodeUnsafeRow(encodePrefixKeyForRangeScan(prefixKey)) + } val result = if (orderingOrdinals.length < keySchema.length) { - val remainingEncoded = encodeUnsafeRow(remainingKeyProjection(row)) + val remainingEncoded = if (avroEnc.isDefined) { + encodeUnsafeRowToAvro( + remainingKeyProjection(row), + avroEnc.get.keySerializer, + remainingKeyAvroType, + out + ) + } else { + encodeUnsafeRow(remainingKeyProjection(row)) + } val (encodedBytes, startingOffset) = encodeColumnFamilyPrefix( rangeScanKeyEncoded.length + remainingEncoded.length + 4 ) @@ -606,9 +986,12 @@ class RangeKeyScanStateEncoder( Platform.copyMemory(keyBytes, decodeKeyStartOffset + 4, prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET, prefixKeyEncodedLen) - val prefixKeyDecodedForRangeScan = decodeToUnsafeRow(prefixKeyEncoded, - numFields = orderingOrdinals.length) - val prefixKeyDecoded = decodePrefixKeyForRangeScan(prefixKeyDecodedForRangeScan) + val prefixKeyDecoded = if (avroEnc.isDefined) { + decodePrefixKeyForRangeScan(prefixKeyEncoded, rangeScanAvroType) + } else { + decodePrefixKeyForRangeScan(decodeToUnsafeRow(prefixKeyEncoded, + numFields = orderingOrdinals.length)) + } if (orderingOrdinals.length < keySchema.length) { // Here we calculate the remainingKeyEncodedLen leveraging the length of keyBytes @@ -620,8 +1003,14 @@ class RangeKeyScanStateEncoder( remainingKeyEncoded, Platform.BYTE_ARRAY_OFFSET, remainingKeyEncodedLen) - val remainingKeyDecoded = decodeToUnsafeRow(remainingKeyEncoded, - numFields = keySchema.length - orderingOrdinals.length) + val remainingKeyDecoded = if (avroEnc.isDefined) { + decodeFromAvroToUnsafeRow(remainingKeyEncoded, + avroEnc.get.keyDeserializer, + remainingKeyAvroType, remainingKeyAvroProjection) + } else { + decodeToUnsafeRow(remainingKeyEncoded, + numFields = keySchema.length - orderingOrdinals.length) + } val joined = joinedRowOnKey.withLeft(prefixKeyDecoded).withRight(remainingKeyDecoded) val restored = restoreKeyProjection(joined) @@ -634,7 +1023,11 @@ class RangeKeyScanStateEncoder( } override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = { - val rangeScanKeyEncoded = encodeUnsafeRow(encodePrefixKeyForRangeScan(prefixKey)) + val rangeScanKeyEncoded = if (avroEnc.isDefined) { + encodePrefixKeyForRangeScan(prefixKey, rangeScanAvroType) + } else { + encodeUnsafeRow(encodePrefixKeyForRangeScan(prefixKey)) + } val (prefix, startingOffset) = encodeColumnFamilyPrefix(rangeScanKeyEncoded.length + 4) Platform.putInt(prefix, startingOffset, rangeScanKeyEncoded.length) @@ -653,6 +1046,7 @@ class RangeKeyScanStateEncoder( * It uses the first byte of the generated byte array to store the version the describes how the * row is encoded in the rest of the byte array. Currently, the default version is 0, * + * If the avroEnc is specified, we are using Avro encoding for this column family's keys * VERSION 0: [ VERSION (1 byte) | ROW (N bytes) ] * The bytes of a UnsafeRow is written unmodified to starting from offset 1 * (offset 0 is the version byte of value 0). 
That is, if the unsafe row has N bytes, @@ -661,19 +1055,27 @@ class RangeKeyScanStateEncoder( class NoPrefixKeyStateEncoder( keySchema: StructType, useColumnFamilies: Boolean = false, - virtualColFamilyId: Option[Short] = None) - extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) { + virtualColFamilyId: Option[Short] = None, + avroEnc: Option[AvroEncoder] = None) + extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) with Logging { import RocksDBStateEncoder._ // Reusable objects + private val usingAvroEncoding = avroEnc.isDefined private val keyRow = new UnsafeRow(keySchema.size) + private lazy val keyAvroType = SchemaConverters.toAvroType(keySchema) + private val keyProj = UnsafeProjection.create(keySchema) override def encodeKey(row: UnsafeRow): Array[Byte] = { if (!useColumnFamilies) { encodeUnsafeRow(row) } else { - val bytesToEncode = row.getBytes + // If avroEnc is defined, we know that we need to use Avro to + // encode this UnsafeRow to Avro bytes + val bytesToEncode = if (usingAvroEncoding) { + encodeUnsafeRowToAvro(row, avroEnc.get.keySerializer, keyAvroType, out) + } else row.getBytes val (encodedBytes, startingOffset) = encodeColumnFamilyPrefix( bytesToEncode.length + STATE_ENCODING_NUM_VERSION_BYTES @@ -697,11 +1099,21 @@ class NoPrefixKeyStateEncoder( if (useColumnFamilies) { if (keyBytes != null) { // Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st offset. See Platform. - keyRow.pointTo( - keyBytes, - decodeKeyStartOffset + STATE_ENCODING_NUM_VERSION_BYTES, - keyBytes.length - STATE_ENCODING_NUM_VERSION_BYTES - VIRTUAL_COL_FAMILY_PREFIX_BYTES) - keyRow + if (usingAvroEncoding) { + val dataLength = keyBytes.length - STATE_ENCODING_NUM_VERSION_BYTES - + VIRTUAL_COL_FAMILY_PREFIX_BYTES + val avroBytes = new Array[Byte](dataLength) + Platform.copyMemory( + keyBytes, decodeKeyStartOffset + STATE_ENCODING_NUM_VERSION_BYTES, + avroBytes, Platform.BYTE_ARRAY_OFFSET, dataLength) + decodeFromAvroToUnsafeRow(avroBytes, avroEnc.get.keyDeserializer, keyAvroType, keyProj) + } else { + keyRow.pointTo( + keyBytes, + decodeKeyStartOffset + STATE_ENCODING_NUM_VERSION_BYTES, + keyBytes.length - STATE_ENCODING_NUM_VERSION_BYTES - VIRTUAL_COL_FAMILY_PREFIX_BYTES) + keyRow + } } else { null } @@ -727,17 +1139,28 @@ class NoPrefixKeyStateEncoder( * This encoder supports RocksDB StringAppendOperator merge operator. Values encoded can be * merged in RocksDB using merge operation, and all merged values can be read using decodeValues * operation. 
+ * If the avroEnc is specified, we are using Avro encoding for this column family's values */ -class MultiValuedStateEncoder(valueSchema: StructType) +class MultiValuedStateEncoder( + valueSchema: StructType, + avroEnc: Option[AvroEncoder] = None) extends RocksDBValueStateEncoder with Logging { import RocksDBStateEncoder._ + private val usingAvroEncoding = avroEnc.isDefined // Reusable objects + private val out = new ByteArrayOutputStream private val valueRow = new UnsafeRow(valueSchema.size) + private lazy val valueAvroType = SchemaConverters.toAvroType(valueSchema) + private val valueProj = UnsafeProjection.create(valueSchema) override def encodeValue(row: UnsafeRow): Array[Byte] = { - val bytes = encodeUnsafeRow(row) + val bytes = if (usingAvroEncoding) { + encodeUnsafeRowToAvro(row, avroEnc.get.valueSerializer, valueAvroType, out) + } else { + encodeUnsafeRow(row) + } val numBytes = bytes.length val encodedBytes = new Array[Byte](java.lang.Integer.BYTES + bytes.length) @@ -756,7 +1179,12 @@ class MultiValuedStateEncoder(valueSchema: StructType) val encodedValue = new Array[Byte](numBytes) Platform.copyMemory(valueBytes, java.lang.Integer.BYTES + Platform.BYTE_ARRAY_OFFSET, encodedValue, Platform.BYTE_ARRAY_OFFSET, numBytes) - decodeToUnsafeRow(encodedValue, valueRow) + if (usingAvroEncoding) { + decodeFromAvroToUnsafeRow( + encodedValue, avroEnc.get.valueDeserializer, valueAvroType, valueProj) + } else { + decodeToUnsafeRow(encodedValue, valueRow) + } } } @@ -782,7 +1210,12 @@ class MultiValuedStateEncoder(valueSchema: StructType) pos += numBytes pos += 1 // eat the delimiter character - decodeToUnsafeRow(encodedValue, valueRow) + if (usingAvroEncoding) { + decodeFromAvroToUnsafeRow( + encodedValue, avroEnc.get.valueDeserializer, valueAvroType, valueProj) + } else { + decodeToUnsafeRow(encodedValue, valueRow) + } } } } @@ -802,16 +1235,29 @@ class MultiValuedStateEncoder(valueSchema: StructType) * The bytes of a UnsafeRow is written unmodified to starting from offset 1 * (offset 0 is the version byte of value 0). That is, if the unsafe row has N bytes, * then the generated array byte will be N+1 bytes. + * If the avroEnc is specified, we are using Avro encoding for this column family's values */ -class SingleValueStateEncoder(valueSchema: StructType) - extends RocksDBValueStateEncoder { +class SingleValueStateEncoder( + valueSchema: StructType, + avroEnc: Option[AvroEncoder] = None) + extends RocksDBValueStateEncoder with Logging { import RocksDBStateEncoder._ + private val usingAvroEncoding = avroEnc.isDefined // Reusable objects + private val out = new ByteArrayOutputStream private val valueRow = new UnsafeRow(valueSchema.size) + private lazy val valueAvroType = SchemaConverters.toAvroType(valueSchema) + private val valueProj = UnsafeProjection.create(valueSchema) - override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row) + override def encodeValue(row: UnsafeRow): Array[Byte] = { + if (usingAvroEncoding) { + encodeUnsafeRowToAvro(row, avroEnc.get.valueSerializer, valueAvroType, out) + } else { + encodeUnsafeRow(row) + } + } /** * Decode byte array for a value to a UnsafeRow. @@ -820,7 +1266,15 @@ class SingleValueStateEncoder(valueSchema: StructType) * the given byte array. 
*/ override def decodeValue(valueBytes: Array[Byte]): UnsafeRow = { - decodeToUnsafeRow(valueBytes, valueRow) + if (valueBytes == null) { + return null + } + if (usingAvroEncoding) { + decodeFromAvroToUnsafeRow( + valueBytes, avroEnc.get.valueDeserializer, valueAvroType, valueProj) + } else { + decodeToUnsafeRow(valueBytes, valueRow) + } } override def supportsMultipleValuesPerKey: Boolean = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 1fc6ab5910c6c..e5a4175aeec1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -18,10 +18,12 @@ package org.apache.spark.sql.execution.streaming.state import java.io._ -import java.util.concurrent.ConcurrentHashMap +import java.util.UUID +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import scala.util.control.NonFatal +import com.google.common.cache.CacheBuilder import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -29,11 +31,12 @@ import org.apache.spark.{SparkConf, SparkEnv, SparkException} import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys._ import org.apache.spark.io.CompressionCodec +import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions, AvroSerializer, SchemaConverters} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.streaming.CheckpointFileManager +import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, StreamExecution} import org.apache.spark.sql.types.StructType -import org.apache.spark.util.Utils +import org.apache.spark.util.{NonFateSharingCache, Utils} private[sql] class RocksDBStateStoreProvider extends StateStoreProvider with Logging with Closeable @@ -74,10 +77,17 @@ private[sql] class RocksDBStateStoreProvider isInternal: Boolean = false): Unit = { verifyColFamilyCreationOrDeletion("create_col_family", colFamilyName, isInternal) val newColFamilyId = rocksDB.createColFamilyIfAbsent(colFamilyName) + // Create cache key using store ID to avoid collisions + val avroEncCacheKey = s"${getRunId(hadoopConf)}_${stateStoreId.operatorId}_" + + s"${stateStoreId.partitionId}_$colFamilyName" + + val avroEnc = getAvroEnc( + stateStoreEncoding, avroEncCacheKey, keyStateEncoderSpec, valueSchema) + keyValueEncoderMap.putIfAbsent(colFamilyName, (RocksDBStateEncoder.getKeyEncoder(keyStateEncoderSpec, useColumnFamilies, - Some(newColFamilyId)), RocksDBStateEncoder.getValueEncoder(valueSchema, - useMultipleValuesPerKey))) + Some(newColFamilyId), avroEnc), RocksDBStateEncoder.getValueEncoder(valueSchema, + useMultipleValuesPerKey, avroEnc))) } override def get(key: UnsafeRow, colFamilyName: String): UnsafeRow = { @@ -364,6 +374,7 @@ private[sql] class RocksDBStateStoreProvider this.storeConf = storeConf this.hadoopConf = hadoopConf this.useColumnFamilies = useColumnFamilies + this.stateStoreEncoding = storeConf.stateStoreEncodingFormat if (useMultipleValuesPerKey) { require(useColumnFamilies, "Multiple values per key support requires column families to be" + @@ -377,10 +388,17 @@ private[sql] class RocksDBStateStoreProvider defaultColFamilyId = 
Some(rocksDB.createColFamilyIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME)) } + val colFamilyName = StateStore.DEFAULT_COL_FAMILY_NAME + // Create cache key using store ID to avoid collisions + val avroEncCacheKey = s"${getRunId(hadoopConf)}_${stateStoreId.operatorId}_" + + s"${stateStoreId.partitionId}_$colFamilyName" + val avroEnc = getAvroEnc( + stateStoreEncoding, avroEncCacheKey, keyStateEncoderSpec, valueSchema) + keyValueEncoderMap.putIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME, (RocksDBStateEncoder.getKeyEncoder(keyStateEncoderSpec, - useColumnFamilies, defaultColFamilyId), - RocksDBStateEncoder.getValueEncoder(valueSchema, useMultipleValuesPerKey))) + useColumnFamilies, defaultColFamilyId, avroEnc), + RocksDBStateEncoder.getValueEncoder(valueSchema, useMultipleValuesPerKey, avroEnc))) } override def stateStoreId: StateStoreId = stateStoreId_ @@ -458,6 +476,7 @@ private[sql] class RocksDBStateStoreProvider @volatile private var storeConf: StateStoreConf = _ @volatile private var hadoopConf: Configuration = _ @volatile private var useColumnFamilies: Boolean = _ + @volatile private var stateStoreEncoding: String = _ private[sql] lazy val rocksDB = { val dfsRootDir = stateStoreId.storeCheckpointLocation().toString @@ -593,6 +612,93 @@ object RocksDBStateStoreProvider { val STATE_ENCODING_VERSION: Byte = 0 val VIRTUAL_COL_FAMILY_PREFIX_BYTES = 2 + private val MAX_AVRO_ENCODERS_IN_CACHE = 1000 + // Add the cache at companion object level so it persists across provider instances + private val avroEncoderMap: NonFateSharingCache[String, AvroEncoder] = { + val guavaCache = CacheBuilder.newBuilder() + .maximumSize(MAX_AVRO_ENCODERS_IN_CACHE) // Bound the number of cached encoders + .expireAfterAccess(1, TimeUnit.HOURS) // Evict encoders not accessed within the last hour + .build[String, AvroEncoder]() + + new NonFateSharingCache(guavaCache) + } + + def getAvroEnc( + stateStoreEncoding: String, + avroEncCacheKey: String, + keyStateEncoderSpec: KeyStateEncoderSpec, + valueSchema: StructType): Option[AvroEncoder] = { + + stateStoreEncoding match { + case "avro" => Some( + RocksDBStateStoreProvider.avroEncoderMap.get( + avroEncCacheKey, + new java.util.concurrent.Callable[AvroEncoder] { + override def call(): AvroEncoder = createAvroEnc(keyStateEncoderSpec, valueSchema) + } + ) + ) + case "unsaferow" => None + } + } + + private def getRunId(hadoopConf: Configuration): String = { + val runId = hadoopConf.get(StreamExecution.RUN_ID_KEY) + if (runId != null) { + runId + } else { + assert(Utils.isTesting, "Failed to find the streaming query run id in the Hadoop configuration") + UUID.randomUUID().toString + } + } + + private def getAvroSerializer(schema: StructType): AvroSerializer = { + val avroType = SchemaConverters.toAvroType(schema) + new AvroSerializer(schema, avroType, nullable = false) + } + + private def getAvroDeserializer(schema: StructType): AvroDeserializer = { + val avroType = SchemaConverters.toAvroType(schema) + val avroOptions = AvroOptions(Map.empty) + new AvroDeserializer(avroType, schema, + avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType, + avroOptions.stableIdPrefixForUnionType, avroOptions.recursiveFieldMaxDepth) + } + + private def createAvroEnc( + keyStateEncoderSpec: KeyStateEncoderSpec, + valueSchema: StructType + ): AvroEncoder = { + val valueSerializer = getAvroSerializer(valueSchema) + val valueDeserializer = getAvroDeserializer(valueSchema) + val keySchema = keyStateEncoderSpec match { + case NoPrefixKeyStateEncoderSpec(schema) => + schema + case PrefixKeyScanStateEncoderSpec(schema, 
numColsPrefixKey) => + StructType(schema.take(numColsPrefixKey)) + case RangeKeyScanStateEncoderSpec(schema, orderingOrdinals) => + val remainingSchema = { + 0.until(schema.length).diff(orderingOrdinals).map { ordinal => + schema(ordinal) + } + } + StructType(remainingSchema) + } + val suffixKeySchema = keyStateEncoderSpec match { + case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) => + Some(StructType(schema.drop(numColsPrefixKey))) + case _ => None + } + AvroEncoder( + getAvroSerializer(keySchema), + getAvroDeserializer(keySchema), + valueSerializer, + valueDeserializer, + suffixKeySchema.map(getAvroSerializer), + suffixKeySchema.map(getAvroDeserializer) + ) + } + // Native operation latencies report as latency in microseconds // as SQLMetrics support millis. Convert the value to millis val CUSTOM_METRIC_GET_TIME = StateStoreCustomTimingMetric( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala index 721d72b6a0991..48b15ac04f40b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala @@ -24,6 +24,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.internal.{Logging, LogKeys, MDC} +import org.apache.spark.sql.avro.{AvroDeserializer, AvroSerializer} import org.apache.spark.sql.catalyst.util.UnsafeRowUtils import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, StatefulOperatorStateInfo} import org.apache.spark.sql.execution.streaming.state.SchemaHelper.{SchemaReader, SchemaWriter} @@ -37,6 +38,30 @@ case class StateSchemaValidationResult( schemaPath: String ) +/** + * An Avro-based encoder used for serializing between UnsafeRow and Avro + * byte arrays in RocksDB state stores. + * + * This encoder is primarily utilized by [[RocksDBStateStoreProvider]] and [[RocksDBStateEncoder]] + * to handle serialization and deserialization of state store data. 
+ * + * @param keySerializer Serializer for converting state store keys to Avro format + * @param keyDeserializer Deserializer for converting Avro-encoded keys back to UnsafeRow + * @param valueSerializer Serializer for converting state store values to Avro format + * @param valueDeserializer Deserializer for converting Avro-encoded values back to UnsafeRow + * @param suffixKeySerializer Optional serializer for handling suffix keys in Avro format + * @param suffixKeyDeserializer Optional deserializer for converting Avro-encoded suffix + * keys back to UnsafeRow + */ +case class AvroEncoder( + keySerializer: AvroSerializer, + keyDeserializer: AvroDeserializer, + valueSerializer: AvroSerializer, + valueDeserializer: AvroDeserializer, + suffixKeySerializer: Option[AvroSerializer] = None, + suffixKeyDeserializer: Option[AvroDeserializer] = None +) extends Serializable + // Used to represent the schema of a column family in the state store case class StateStoreColFamilySchema( colFamilyName: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 72bc3ca33054d..e2b93c147891d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -37,10 +37,22 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.util.UnsafeRowUtils import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo +import org.apache.spark.sql.execution.streaming.{StatefulOperatorStateInfo, StreamExecution} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{NextIterator, ThreadUtils, Utils} +sealed trait StateStoreEncoding { + override def toString: String = this match { + case StateStoreEncoding.UnsafeRow => "unsaferow" + case StateStoreEncoding.Avro => "avro" + } +} + +object StateStoreEncoding { + case object UnsafeRow extends StateStoreEncoding + case object Avro extends StateStoreEncoding +} + /** * Base trait for a versioned key-value store which provides read operations. Each instance of a * `ReadStateStore` represents a specific version of state data, and such instances are created @@ -769,6 +781,7 @@ object StateStore extends Logging { if (version < 0) { throw QueryExecutionErrors.unexpectedStateStoreVersion(version) } + hadoopConf.set(StreamExecution.RUN_ID_KEY, storeProviderId.queryRunId.toString) val storeProvider = getStateStoreProvider(storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, useColumnFamilies, storeConf, hadoopConf, useMultipleValuesPerKey) storeProvider.getStore(version, stateStoreCkptId) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index c8af395e996d8..9d26bf8fdf2e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -83,6 +83,9 @@ class StateStoreConf( /** The interval of maintenance tasks. */ val maintenanceInterval = sqlConf.streamingMaintenanceInterval + /** The encoding format used by stateful operators to store information in the state store. 
*/ + val stateStoreEncodingFormat = sqlConf.stateStoreEncodingFormat + /** * When creating new state store checkpoint, which format version to use. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala index baab6327b35c1..af64f563cf7b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.io.CompressionCodec import org.apache.spark.sql.{Encoders, Row} import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBFileManager, RocksDBStateStoreProvider, TestClass} +import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, AlsoTestWithEncodingTypes, RocksDBFileManager, RocksDBStateStoreProvider, TestClass} import org.apache.spark.sql.functions.{col, explode, timestamp_seconds} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.{InputMapRow, ListState, MapInputEvent, MapOutputEvent, MapStateTTLProcessor, MaxEventTimeStatefulProcessor, OutputMode, RunningCountStatefulProcessor, RunningCountStatefulProcessorWithProcTimeTimerUpdates, StatefulProcessor, StateStoreMetricsTest, TestMapStateProcessor, TimeMode, TimerValues, TransformWithStateSuiteUtils, Trigger, TTLConfig, ValueState} @@ -126,7 +126,7 @@ class SessionGroupsStatefulProcessorWithTTL extends * Test suite to verify integration of state data source reader with the transformWithState operator */ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest - with AlsoTestWithChangelogCheckpointingEnabled { + with AlsoTestWithChangelogCheckpointingEnabled with AlsoTestWithEncodingTypes { import testImplicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index e1bd9dd38066b..0abdcadefbd55 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -43,6 +43,7 @@ import org.apache.spark.util.Utils @ExtendedSQLTest class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvider] with AlsoTestWithChangelogCheckpointingEnabled + with AlsoTestWithEncodingTypes with SharedSparkSession with BeforeAndAfter { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala index 637eb49130305..61ca8e7c32f61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala @@ -86,6 +86,19 @@ trait RocksDBStateStoreChangelogCheckpointingTestUtil { } } +trait AlsoTestWithEncodingTypes extends SQLTestUtils { + override protected def test(testName: String, testTags: 
Tag*)(testBody: => Any) + (implicit pos: Position): Unit = { + Seq("unsaferow", "avro").foreach { encoding => + super.test(s"$testName (encoding = $encoding)", testTags: _*) { + withSQLConf(SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key -> encoding) { + testBody + } + } + } + } +} + trait AlsoTestWithChangelogCheckpointingEnabled extends SQLTestUtils with RocksDBStateStoreChangelogCheckpointingTestUtil { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala index 55d08cd8f12a7..8984d9b0845b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala @@ -423,7 +423,7 @@ class ValueStateSuite extends StateVariableSuiteBase { * types (ValueState, ListState, MapState) used in arbitrary stateful operators. */ abstract class StateVariableSuiteBase extends SharedSparkSession - with BeforeAndAfter { + with BeforeAndAfter with AlsoTestWithEncodingTypes { before { StateStore.stop() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala index 88862e2ad0791..5d88db0d01ba4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming import org.apache.spark.SparkIllegalArgumentException import org.apache.spark.sql.Encoders import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider} +import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, AlsoTestWithEncodingTypes, RocksDBStateStoreProvider} import org.apache.spark.sql.internal.SQLConf case class InputRow(key: String, action: String, value: String) @@ -127,7 +127,8 @@ class ToggleSaveAndEmitProcessor } class TransformWithListStateSuite extends StreamTest - with AlsoTestWithChangelogCheckpointingEnabled { + with AlsoTestWithChangelogCheckpointingEnabled + with AlsoTestWithEncodingTypes { import testImplicits._ test("test appending null value in list state throw exception") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala index 76c5cbeee424b..ec6ff4fcceb67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming import org.apache.spark.SparkIllegalArgumentException import org.apache.spark.sql.Encoders import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider} +import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, AlsoTestWithEncodingTypes, RocksDBStateStoreProvider} import org.apache.spark.sql.internal.SQLConf case class InputMapRow(key: String, action: String, value: 
(String, String)) @@ -81,7 +81,8 @@ class TestMapStateProcessor * operators such as transformWithState. */ class TransformWithMapStateSuite extends StreamTest - with AlsoTestWithChangelogCheckpointingEnabled { + with AlsoTestWithChangelogCheckpointingEnabled + with AlsoTestWithEncodingTypes { import testImplicits._ private def testMapStateWithNullUserKey(inputMapRow: InputMapRow): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 505775d4f6a9b..91a47645f4179 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -429,11 +429,12 @@ class SleepingTimerProcessor extends StatefulProcessor[String, String, String] { * Class that adds tests for transformWithState stateful streaming operator */ class TransformWithStateSuite extends StateStoreMetricsTest - with AlsoTestWithChangelogCheckpointingEnabled { + with AlsoTestWithChangelogCheckpointingEnabled with AlsoTestWithEncodingTypes { import testImplicits._ - test("transformWithState - streaming with rocksdb and invalid processor should fail") { + test("transformWithState - streaming with rocksdb and" + + " invalid processor should fail") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -688,7 +689,8 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - streaming with rocksdb and event time based timer") { + test("transformWithState - streaming with rocksdb and event " + + "time based timer") { val inputData = MemoryStream[(String, Int)] val result = inputData.toDS() @@ -778,7 +780,8 @@ class TransformWithStateSuite extends StateStoreMetricsTest ) } - test("Use statefulProcessor without transformWithState - handle should be absent") { + test("Use statefulProcessor without transformWithState -" + + " handle should be absent") { val processor = new RunningCountStatefulProcessor() val ex = intercept[Exception] { processor.getHandle @@ -1034,7 +1037,8 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - verify StateSchemaV3 writes correct SQL schema of key/value") { + test("transformWithState - verify StateSchemaV3 writes " + + "correct SQL schema of key/value") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -1605,7 +1609,8 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - verify that schema file is kept after metadata is purged") { + test("transformWithState - verify that schema file " + + "is kept after metadata is purged") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala index 2ddf69aa49e04..75fda9630779e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala @@ -21,7 +21,7 @@ import java.sql.Timestamp import java.time.Duration import 
org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider +import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, AlsoTestWithEncodingTypes, RocksDBStateStoreProvider} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock @@ -41,7 +41,8 @@ case class OutputEvent( * Test suite base for TransformWithState with TTL support. */ abstract class TransformWithStateTTLTest - extends StreamTest { + extends StreamTest with AlsoTestWithChangelogCheckpointingEnabled + with AlsoTestWithEncodingTypes { import testImplicits._ def getProcessor(ttlConfig: TTLConfig): StatefulProcessor[String, InputEvent, OutputEvent] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala index 21c3beb79314c..b19c126c7386b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala @@ -262,7 +262,8 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest { } } - test("verify StateSchemaV3 writes correct SQL schema of key/value and with TTL") { + test("verify StateSchemaV3 writes correct SQL " + + "schema of key/value and with TTL") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key ->