From e939fa8312c92ab75cc52265e86fd816d842f9a0 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Sat, 23 Nov 2024 13:36:34 -0800 Subject: [PATCH 01/17] init --- .../apache/spark/sql/internal/SQLConf.scala | 13 + .../StateStoreColumnFamilySchemaUtils.scala | 41 +- .../execution/streaming/StreamExecution.scala | 1 + .../streaming/state/RocksDBStateEncoder.scala | 965 +++++++++++++----- .../state/RocksDBStateStoreProvider.scala | 138 ++- .../StateSchemaCompatibilityChecker.scala | 12 + .../streaming/state/StateStore.scala | 17 +- .../streaming/state/StateStoreConf.scala | 4 + ...ateDataSourceTransformWithStateSuite.scala | 3 +- .../state/RocksDBStateStoreSuite.scala | 1 + .../streaming/state/RocksDBSuite.scala | 13 + .../streaming/state/ValueStateSuite.scala | 2 +- .../TransformWithListStateSuite.scala | 5 +- .../TransformWithMapStateSuite.scala | 5 +- .../TransformWithStateChainingSuite.scala | 5 +- .../streaming/TransformWithStateSuite.scala | 3 +- .../streaming/TransformWithStateTTLTest.scala | 5 +- 17 files changed, 930 insertions(+), 303 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 123759c6c8b80..f2c1d11a226fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2230,6 +2230,17 @@ object SQLConf { .intConf .createWithDefault(1) + val STREAMING_STATE_STORE_ENCODING_FORMAT = + buildConf("spark.sql.streaming.stateStore.encodingFormat") + .doc("The encoding format used for stateful operators to store information " + + "in the state store") + .version("4.0.0") + .stringConf + .transform(_.toLowerCase(Locale.ROOT)) + .checkValue(v => Set("unsaferow", "avro").contains(v), + "Valid values are 'unsaferow' and 'avro'") + .createWithDefault("unsaferow") + val STATE_STORE_COMPRESSION_CODEC = buildConf("spark.sql.streaming.stateStore.compression.codec") .internal() @@ -5607,6 +5618,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def stateStoreCheckpointFormatVersion: Int = getConf(STATE_STORE_CHECKPOINT_FORMAT_VERSION) + def stateStoreEncodingFormat: String = getConf(STREAMING_STATE_STORE_ENCODING_FORMAT) + def checkpointRenamedFileCheck: Boolean = getConf(CHECKPOINT_RENAMEDFILE_CHECK_ENABLED) def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala index 7da8408f98b0f..585298fa4c993 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala @@ -20,10 +20,49 @@ import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._ import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, StateStoreColFamilySchema} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ object StateStoreColumnFamilySchemaUtils { + /** + * Avro uses zig-zag encoding for some fixed-length types, like Longs and Ints. 
For range scans + * we want to use big-endian encoding, so we need to convert the source schema to replace these + * types with BinaryType. + * + * @param schema The schema to convert + * @param ordinals If non-empty, only convert fields at these ordinals. + * If empty, convert all fields. + */ + def convertForRangeScan(schema: StructType, ordinals: Seq[Int] = Seq.empty): StructType = { + val ordinalSet = ordinals.toSet + + StructType(schema.fields.zipWithIndex.flatMap { case (field, idx) => + if ((ordinals.isEmpty || ordinalSet.contains(idx)) && isFixedSize(field.dataType)) { + // For each numeric field, create two fields: + // 1. Byte marker for null, positive, or negative values + // 2. The original numeric value in big-endian format + // Byte type is converted to Int in Avro, which doesn't work for us as Avro + // uses zig-zag encoding as opposed to big-endian for Ints + Seq( + StructField(s"${field.name}_marker", BinaryType, nullable = false), + field.copy(name = s"${field.name}_value", BinaryType) + ) + } else { + Seq(field) + } + }) + } + + private def isFixedSize(dataType: DataType): Boolean = dataType match { + case _: ByteType | _: BooleanType | _: ShortType | _: IntegerType | _: LongType | + _: FloatType | _: DoubleType => true + case _ => false + } + + def getTtlColFamilyName(stateName: String): String = { + "$ttl_" + stateName + } + def getValueStateSchema[T]( stateName: String, keyEncoder: ExpressionEncoder[Any], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index bd501c9357234..44202bb0d2944 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -715,6 +715,7 @@ abstract class StreamExecution( object StreamExecution { val QUERY_ID_KEY = "sql.streaming.queryId" + val RUN_ID_KEY = "sql.streaming.runId" val IS_CONTINUOUS_PROCESSING = "__is_continuous_processing" val IO_EXCEPTION_NAMES = Seq( classOf[InterruptedException].getName, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index 4c7a226e0973f..d718ec54e6ecb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -17,30 +17,660 @@ package org.apache.spark.sql.execution.streaming.state +import java.io.ByteArrayOutputStream import java.lang.Double.{doubleToRawLongBits, longBitsToDouble} import java.lang.Float.{floatToRawIntBits, intBitsToFloat} import java.nio.{ByteBuffer, ByteOrder} +import org.apache.avro.Schema +import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord} +import org.apache.avro.io.{DecoderFactory, EncoderFactory} + import org.apache.spark.internal.Logging +import org.apache.spark.sql.avro.{AvroDeserializer, AvroSerializer, SchemaConverters} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter +import org.apache.spark.sql.execution.streaming.StateStoreColumnFamilySchemaUtils import 
org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{STATE_ENCODING_NUM_VERSION_BYTES, STATE_ENCODING_VERSION, VIRTUAL_COL_FAMILY_PREFIX_BYTES} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform -sealed trait RocksDBKeyStateEncoder { - def supportPrefixKeyScan: Boolean - def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] - def encodeKey(row: UnsafeRow): Array[Byte] - def decodeKey(keyBytes: Array[Byte]): UnsafeRow - def getColumnFamilyIdBytes(): Array[Byte] -} +sealed trait RocksDBKeyStateEncoder { + def supportPrefixKeyScan: Boolean + def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] + def encodeKey(row: UnsafeRow): Array[Byte] + def decodeKey(keyBytes: Array[Byte]): UnsafeRow + def getColumnFamilyIdBytes(): Array[Byte] +} + +sealed trait RocksDBValueStateEncoder { + def supportsMultipleValuesPerKey: Boolean + def encodeValue(row: UnsafeRow): Array[Byte] + def decodeValue(valueBytes: Array[Byte]): UnsafeRow + def decodeValues(valueBytes: Array[Byte]): Iterator[UnsafeRow] +} + +trait StateEncoder { + def encodeKey(row: UnsafeRow): Array[Byte] + def encodeRemainingKey(row: UnsafeRow): Array[Byte] + def encodePrefixKeyForRangeScan(row: UnsafeRow): Array[Byte] + def encodeValue(row: UnsafeRow): Array[Byte] + + def decodeKey(bytes: Array[Byte]): UnsafeRow + def decodeRemainingKey(bytes: Array[Byte]): UnsafeRow + def decodePrefixKeyForRangeScan(bytes: Array[Byte]): UnsafeRow + def decodeValue(bytes: Array[Byte]): UnsafeRow +} + +abstract class RocksDBDataEncoder( + keyStateEncoderSpec: KeyStateEncoderSpec, + valueSchema: StructType) extends StateEncoder { + + val keySchema = keyStateEncoderSpec.keySchema + val reusedKeyRow = new UnsafeRow(keyStateEncoderSpec.keySchema.length) + val reusedValueRow = new UnsafeRow(valueSchema.length) + + // bit masks used for checking sign or flipping all bits for negative float/double values + val floatFlipBitMask = 0xFFFFFFFF + val floatSignBitMask = 0x80000000 + + val doubleFlipBitMask = 0xFFFFFFFFFFFFFFFFL + val doubleSignBitMask = 0x8000000000000000L + + // Byte markers used to identify whether the value is null, negative or positive + // To ensure sorted ordering, we use the lowest byte value for negative numbers followed by + // positive numbers and then null values. + val negativeValMarker: Byte = 0x00.toByte + val positiveValMarker: Byte = 0x01.toByte + val nullValMarker: Byte = 0x02.toByte + + /** + * Encode the UnsafeRow of N bytes as a N+1 byte array. + * @note This creates a new byte array and memcopies the UnsafeRow to the new array. + */ + def encodeUnsafeRow(row: UnsafeRow): Array[Byte] = { + val bytesToEncode = row.getBytes + val encodedBytes = new Array[Byte](bytesToEncode.length + STATE_ENCODING_NUM_VERSION_BYTES) + Platform.putByte(encodedBytes, Platform.BYTE_ARRAY_OFFSET, STATE_ENCODING_VERSION) + // Platform.BYTE_ARRAY_OFFSET is the recommended way to memcopy b/w byte arrays. See Platform. + Platform.copyMemory( + bytesToEncode, Platform.BYTE_ARRAY_OFFSET, + encodedBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES, + bytesToEncode.length) + encodedBytes + } + + def decodeToUnsafeRow(bytes: Array[Byte], numFields: Int): UnsafeRow = { + if (bytes != null) { + val row = new UnsafeRow(numFields) + decodeToUnsafeRow(bytes, row) + } else { + null + } + } + + def decodeToUnsafeRow(bytes: Array[Byte], reusedRow: UnsafeRow): UnsafeRow = { + if (bytes != null) { + // Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st offset. See Platform. 
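+      // The encoded layout is [version byte][UnsafeRow bytes], so point the reused row past the
+      // STATE_ENCODING_NUM_VERSION_BYTES version prefix and shrink the visible length to match.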
+ reusedRow.pointTo( + bytes, + Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES, + bytes.length - STATE_ENCODING_NUM_VERSION_BYTES) + reusedRow + } else { + null + } + } +} + +class UnsafeRowStateEncoder( + keyStateEncoderSpec: KeyStateEncoderSpec, + valueSchema: StructType) extends RocksDBDataEncoder(keyStateEncoderSpec, valueSchema) { + + override def encodeKey(row: UnsafeRow): Array[Byte] = { + encodeUnsafeRow(row) + } + + override def encodeRemainingKey(row: UnsafeRow): Array[Byte] = { + encodeUnsafeRow(row) + } + + override def encodePrefixKeyForRangeScan(row: UnsafeRow): Array[Byte] = { + assert(keyStateEncoderSpec.isInstanceOf[RangeKeyScanStateEncoderSpec]) + val rsk = keyStateEncoderSpec.asInstanceOf[RangeKeyScanStateEncoderSpec] + val rangeScanKeyFieldsWithOrdinal = rsk.orderingOrdinals.map { ordinal => + val field = rsk.keySchema(ordinal) + (field, ordinal) + } + val writer = new UnsafeRowWriter(rsk.orderingOrdinals.length) + writer.resetRowWriter() + rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) => + val field = fieldWithOrdinal._1 + val value = row.get(idx, field.dataType) + // Note that we cannot allocate a smaller buffer here even if the value is null + // because the effective byte array is considered variable size and needs to have + // the same size across all rows for the ordering to work as expected. + val bbuf = ByteBuffer.allocate(field.dataType.defaultSize + 1) + bbuf.order(ByteOrder.BIG_ENDIAN) + if (value == null) { + bbuf.put(nullValMarker) + writer.write(idx, bbuf.array()) + } else { + field.dataType match { + case BooleanType => + case ByteType => + val byteVal = value.asInstanceOf[Byte] + val signCol = if (byteVal < 0) { + negativeValMarker + } else { + positiveValMarker + } + bbuf.put(signCol) + bbuf.put(byteVal) + writer.write(idx, bbuf.array()) + + case ShortType => + val shortVal = value.asInstanceOf[Short] + val signCol = if (shortVal < 0) { + negativeValMarker + } else { + positiveValMarker + } + bbuf.put(signCol) + bbuf.putShort(shortVal) + writer.write(idx, bbuf.array()) + + case IntegerType => + val intVal = value.asInstanceOf[Int] + val signCol = if (intVal < 0) { + negativeValMarker + } else { + positiveValMarker + } + bbuf.put(signCol) + bbuf.putInt(intVal) + writer.write(idx, bbuf.array()) + + case LongType => + val longVal = value.asInstanceOf[Long] + val signCol = if (longVal < 0) { + negativeValMarker + } else { + positiveValMarker + } + bbuf.put(signCol) + bbuf.putLong(longVal) + writer.write(idx, bbuf.array()) + + case FloatType => + val floatVal = value.asInstanceOf[Float] + val rawBits = floatToRawIntBits(floatVal) + // perform sign comparison using bit manipulation to ensure NaN values are handled + // correctly + if ((rawBits & floatSignBitMask) != 0) { + // for negative values, we need to flip all the bits to ensure correct ordering + val updatedVal = rawBits ^ floatFlipBitMask + bbuf.put(negativeValMarker) + // convert the bits back to float + bbuf.putFloat(intBitsToFloat(updatedVal)) + } else { + bbuf.put(positiveValMarker) + bbuf.putFloat(floatVal) + } + writer.write(idx, bbuf.array()) + + case DoubleType => + val doubleVal = value.asInstanceOf[Double] + val rawBits = doubleToRawLongBits(doubleVal) + // perform sign comparison using bit manipulation to ensure NaN values are handled + // correctly + if ((rawBits & doubleSignBitMask) != 0) { + // for negative values, we need to flip all the bits to ensure correct ordering + val updatedVal = rawBits ^ doubleFlipBitMask + 
bbuf.put(negativeValMarker) + // convert the bits back to double + bbuf.putDouble(longBitsToDouble(updatedVal)) + } else { + bbuf.put(positiveValMarker) + bbuf.putDouble(doubleVal) + } + writer.write(idx, bbuf.array()) + } + } + } + encodeUnsafeRow(writer.getRow()) + } + + override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row) + + override def decodeKey(bytes: Array[Byte]): UnsafeRow = { + keyStateEncoderSpec match { + case NoPrefixKeyStateEncoderSpec(_) => + decodeToUnsafeRow(bytes, reusedKeyRow) + case PrefixKeyScanStateEncoderSpec(_, numColsPrefixKey) => + decodeToUnsafeRow(bytes, numFields = numColsPrefixKey) + case _ => null + } + } + + override def decodeRemainingKey(bytes: Array[Byte]): UnsafeRow = null + + override def decodePrefixKeyForRangeScan(bytes: Array[Byte]): UnsafeRow = { + assert(keyStateEncoderSpec.isInstanceOf[RangeKeyScanStateEncoderSpec]) + val rsk = keyStateEncoderSpec.asInstanceOf[RangeKeyScanStateEncoderSpec] + val writer = new UnsafeRowWriter(rsk.orderingOrdinals.length) + val rangeScanKeyFieldsWithOrdinal = rsk.orderingOrdinals.map { ordinal => + val field = rsk.keySchema(ordinal) + (field, ordinal) + } + writer.resetRowWriter() + val row = decodeToUnsafeRow(bytes, numFields = rsk.orderingOrdinals.length) + rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) => + val field = fieldWithOrdinal._1 + + val value = row.getBinary(idx) + val bbuf = ByteBuffer.wrap(value.asInstanceOf[Array[Byte]]) + bbuf.order(ByteOrder.BIG_ENDIAN) + val isNullOrSignCol = bbuf.get() + if (isNullOrSignCol == nullValMarker) { + // set the column to null and skip reading the next byte(s) + writer.setNullAt(idx) + } else { + field.dataType match { + case BooleanType => + case ByteType => + writer.write(idx, bbuf.get) + + case ShortType => + writer.write(idx, bbuf.getShort) + + case IntegerType => + writer.write(idx, bbuf.getInt) + + case LongType => + writer.write(idx, bbuf.getLong) + + case FloatType => + if (isNullOrSignCol == negativeValMarker) { + // if the number is negative, get the raw binary bits for the float + // and flip the bits back + val updatedVal = floatToRawIntBits(bbuf.getFloat) ^ floatFlipBitMask + writer.write(idx, intBitsToFloat(updatedVal)) + } else { + writer.write(idx, bbuf.getFloat) + } + + case DoubleType => + if (isNullOrSignCol == negativeValMarker) { + // if the number is negative, get the raw binary bits for the double + // and flip the bits back + val updatedVal = doubleToRawLongBits(bbuf.getDouble) ^ doubleFlipBitMask + writer.write(idx, longBitsToDouble(updatedVal)) + } else { + writer.write(idx, bbuf.getDouble) + } + } + } + } + writer.getRow() + } + + override def decodeValue(bytes: Array[Byte]): UnsafeRow = decodeToUnsafeRow(bytes, reusedValueRow) +} + +class AvroStateEncoder( + keyStateEncoderSpec: KeyStateEncoderSpec, + valueSchema: StructType, + avroEncoder: AvroEncoder) extends RocksDBDataEncoder(keyStateEncoderSpec, valueSchema) + with Logging { + + private lazy val keyAvroType: Schema = SchemaConverters.toAvroType(keySchema) + private lazy val keyProj = UnsafeProjection.create(keySchema) + private lazy val valueProj = UnsafeProjection.create(valueSchema) + + private lazy val valueAvroType: Schema = SchemaConverters.toAvroType(valueSchema) + + // Prefix Key schema and projection definitions used by the Avro Serializers + // and Deserializers + private lazy val prefixKeySchema = keyStateEncoderSpec match { + case PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) => + StructType(keySchema.take 
(numColsPrefixKey)) + case _ => null + } + + private lazy val prefixKeyAvroType = SchemaConverters.toAvroType(prefixKeySchema) + private lazy val prefixKeyProj = UnsafeProjection.create(prefixKeySchema) + + private lazy val rangeScanKeyFieldsWithOrdinal = keyStateEncoderSpec match { + case RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) => + orderingOrdinals.map { ordinal => + val field = keySchema(ordinal) + (field, ordinal) + } + case _ => null + } + + private lazy val rangeScanAvroSchema = StateStoreColumnFamilySchemaUtils.convertForRangeScan( + StructType(rangeScanKeyFieldsWithOrdinal.map(_._1).toArray)) + + private lazy val rangeScanAvroType = SchemaConverters.toAvroType(rangeScanAvroSchema) + + private lazy val rangeScanAvroProjection = UnsafeProjection.create(rangeScanAvroSchema) + + // Existing remainder key schema stuff + // Remaining Key schema and projection definitions used by the Avro Serializers + // and Deserializers + private val remainingKeySchema = keyStateEncoderSpec match { + case PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) => + StructType(keySchema.drop(numColsPrefixKey)) + case RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) => + StructType(0.until(keySchema.length).diff(orderingOrdinals).map(keySchema(_))) + case _ => null + } + + private lazy val remainingKeyAvroType = SchemaConverters.toAvroType(remainingKeySchema) + + private lazy val remainingKeyAvroProjection = UnsafeProjection.create(remainingKeySchema) + + /** + * This method takes an UnsafeRow, and serializes to a byte array using Avro encoding. + */ + def encodeUnsafeRowToAvro( + row: UnsafeRow, + avroSerializer: AvroSerializer, + valueAvroType: Schema, + out: ByteArrayOutputStream): Array[Byte] = { + // InternalRow -> Avro.GenericDataRecord + val avroData = + avroSerializer.serialize(row) + out.reset() + val encoder = EncoderFactory.get().directBinaryEncoder(out, null) + val writer = new GenericDatumWriter[Any]( + valueAvroType) // Defining Avro writer for this struct type + writer.write(avroData, encoder) // Avro.GenericDataRecord -> byte array + encoder.flush() + out.toByteArray + } + + /** + * This method takes a byte array written using Avro encoding, and + * deserializes to an UnsafeRow using the Avro deserializer + */ + def decodeFromAvroToUnsafeRow( + valueBytes: Array[Byte], + avroDeserializer: AvroDeserializer, + valueAvroType: Schema, + valueProj: UnsafeProjection): UnsafeRow = { + if (valueBytes != null) { + val reader = new GenericDatumReader[Any](valueAvroType) + val decoder = DecoderFactory.get().binaryDecoder( + valueBytes, Platform.BYTE_ARRAY_OFFSET, valueBytes.length, null) + // bytes -> Avro.GenericDataRecord + val genericData = reader.read(null, decoder) + // Avro.GenericDataRecord -> InternalRow + val internalRow = avroDeserializer.deserialize( + genericData).orNull.asInstanceOf[InternalRow] + // InternalRow -> UnsafeRow + valueProj.apply(internalRow) + } else { + null + } + } + + private val out = new ByteArrayOutputStream + override def encodeKey(row: UnsafeRow): Array[Byte] = { + keyStateEncoderSpec match { + case NoPrefixKeyStateEncoderSpec(_) => + encodeUnsafeRowToAvro(row, avroEncoder.keySerializer, keyAvroType, out) + case PrefixKeyScanStateEncoderSpec(_, _) => + encodeUnsafeRowToAvro(row, avroEncoder.keySerializer, prefixKeyAvroType, out) + case _ => null + } + } + + override def encodeRemainingKey(row: UnsafeRow): Array[Byte] = { + keyStateEncoderSpec match { + case PrefixKeyScanStateEncoderSpec(_, _) => + encodeUnsafeRowToAvro(row, 
avroEncoder.suffixKeySerializer.get, remainingKeyAvroType, out) + case RangeKeyScanStateEncoderSpec(_, _) => + encodeUnsafeRowToAvro(row, avroEncoder.keySerializer, remainingKeyAvroType, out) + case _ => null + } + } + + override def encodePrefixKeyForRangeScan( + row: UnsafeRow): Array[Byte] = { + val record = new GenericData.Record(rangeScanAvroType) + var fieldIdx = 0 + rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) => + val field = fieldWithOrdinal._1 + val value = row.get(idx, field.dataType) + + // Create marker byte buffer + val markerBuffer = ByteBuffer.allocate(1) + markerBuffer.order(ByteOrder.BIG_ENDIAN) + + if (value == null) { + markerBuffer.put(nullValMarker) + record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array())) + record.put(fieldIdx + 1, ByteBuffer.wrap(new Array[Byte](field.dataType.defaultSize))) + } else { + field.dataType match { + case BooleanType => + markerBuffer.put(positiveValMarker) + record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array())) + val valueBuffer = ByteBuffer.allocate(1) + valueBuffer.put(if (value.asInstanceOf[Boolean]) 1.toByte else 0.toByte) + record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array())) + + case ByteType => + val byteVal = value.asInstanceOf[Byte] + markerBuffer.put(if (byteVal < 0) negativeValMarker else positiveValMarker) + record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array())) + + val valueBuffer = ByteBuffer.allocate(1) + valueBuffer.order(ByteOrder.BIG_ENDIAN) + valueBuffer.put(byteVal) + record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array())) + + case ShortType => + val shortVal = value.asInstanceOf[Short] + markerBuffer.put(if (shortVal < 0) negativeValMarker else positiveValMarker) + record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array())) + + val valueBuffer = ByteBuffer.allocate(2) + valueBuffer.order(ByteOrder.BIG_ENDIAN) + valueBuffer.putShort(shortVal) + record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array())) + + case IntegerType => + val intVal = value.asInstanceOf[Int] + markerBuffer.put(if (intVal < 0) negativeValMarker else positiveValMarker) + record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array())) + + val valueBuffer = ByteBuffer.allocate(4) + valueBuffer.order(ByteOrder.BIG_ENDIAN) + valueBuffer.putInt(intVal) + record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array())) + + case LongType => + val longVal = value.asInstanceOf[Long] + markerBuffer.put(if (longVal < 0) negativeValMarker else positiveValMarker) + record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array())) + + val valueBuffer = ByteBuffer.allocate(8) + valueBuffer.order(ByteOrder.BIG_ENDIAN) + valueBuffer.putLong(longVal) + record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array())) + + case FloatType => + val floatVal = value.asInstanceOf[Float] + val rawBits = floatToRawIntBits(floatVal) + markerBuffer.put(if ((rawBits & floatSignBitMask) != 0) { + negativeValMarker + } else { + positiveValMarker + }) + record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array())) + + val valueBuffer = ByteBuffer.allocate(4) + valueBuffer.order(ByteOrder.BIG_ENDIAN) + if ((rawBits & floatSignBitMask) != 0) { + val updatedVal = rawBits ^ floatFlipBitMask + valueBuffer.putFloat(intBitsToFloat(updatedVal)) + } else { + valueBuffer.putFloat(floatVal) + } + record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array())) + + case DoubleType => + val doubleVal = value.asInstanceOf[Double] + val rawBits = doubleToRawLongBits(doubleVal) + markerBuffer.put(if ((rawBits & doubleSignBitMask) != 0) { + 
negativeValMarker + } else { + positiveValMarker + }) + record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array())) + + val valueBuffer = ByteBuffer.allocate(8) + valueBuffer.order(ByteOrder.BIG_ENDIAN) + if ((rawBits & doubleSignBitMask) != 0) { + val updatedVal = rawBits ^ doubleFlipBitMask + valueBuffer.putDouble(longBitsToDouble(updatedVal)) + } else { + valueBuffer.putDouble(doubleVal) + } + record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array())) + + case _ => throw new UnsupportedOperationException( + s"Range scan encoding not supported for data type: ${field.dataType}") + } + } + fieldIdx += 2 + } + + out.reset() + val writer = new GenericDatumWriter[GenericRecord](rangeScanAvroType) + val encoder = EncoderFactory.get().binaryEncoder(out, null) + writer.write(record, encoder) + encoder.flush() + out.toByteArray + } + + override def encodeValue(row: UnsafeRow): Array[Byte] = + encodeUnsafeRowToAvro(row, avroEncoder.valueSerializer, valueAvroType, out) + + override def decodeKey(bytes: Array[Byte]): UnsafeRow = { + keyStateEncoderSpec match { + case NoPrefixKeyStateEncoderSpec(_) => + decodeFromAvroToUnsafeRow(bytes, avroEncoder.keyDeserializer, keyAvroType, keyProj) + case PrefixKeyScanStateEncoderSpec(_, _) => + decodeFromAvroToUnsafeRow(bytes, avroEncoder.keyDeserializer, keyAvroType, prefixKeyProj) + case _ => null + } + } + -sealed trait RocksDBValueStateEncoder { - def supportsMultipleValuesPerKey: Boolean - def encodeValue(row: UnsafeRow): Array[Byte] - def decodeValue(valueBytes: Array[Byte]): UnsafeRow - def decodeValues(valueBytes: Array[Byte]): Iterator[UnsafeRow] + override def decodeRemainingKey(bytes: Array[Byte]): UnsafeRow = { + keyStateEncoderSpec match { + case PrefixKeyScanStateEncoderSpec(_, _) => + decodeFromAvroToUnsafeRow(bytes, + avroEncoder.suffixKeyDeserializer.get, remainingKeyAvroType, remainingKeyAvroProjection) + case RangeKeyScanStateEncoderSpec(_, _) => + decodeFromAvroToUnsafeRow( + bytes, avroEncoder.keyDeserializer, remainingKeyAvroType, remainingKeyAvroProjection) + case _ => null + } + } + + override def decodePrefixKeyForRangeScan(bytes: Array[Byte]): UnsafeRow = { + val reader = new GenericDatumReader[GenericRecord](rangeScanAvroType) + val decoder = DecoderFactory.get().binaryDecoder(bytes, 0, bytes.length, null) + val record = reader.read(null, decoder) + + val rowWriter = new UnsafeRowWriter(rangeScanKeyFieldsWithOrdinal.length) + rowWriter.resetRowWriter() + + var fieldIdx = 0 + rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) => + val field = fieldWithOrdinal._1 + + val markerBytes = record.get(fieldIdx).asInstanceOf[ByteBuffer].array() + val markerBuf = ByteBuffer.wrap(markerBytes) + markerBuf.order(ByteOrder.BIG_ENDIAN) + val marker = markerBuf.get() + + if (marker == nullValMarker) { + rowWriter.setNullAt(idx) + } else { + field.dataType match { + case BooleanType => + val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array() + rowWriter.write(idx, bytes(0) == 1) + + case ByteType => + val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array() + val valueBuf = ByteBuffer.wrap(bytes) + valueBuf.order(ByteOrder.BIG_ENDIAN) + rowWriter.write(idx, valueBuf.get()) + + case ShortType => + val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array() + val valueBuf = ByteBuffer.wrap(bytes) + valueBuf.order(ByteOrder.BIG_ENDIAN) + rowWriter.write(idx, valueBuf.getShort()) + + case IntegerType => + val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array() + val 
valueBuf = ByteBuffer.wrap(bytes) + valueBuf.order(ByteOrder.BIG_ENDIAN) + rowWriter.write(idx, valueBuf.getInt()) + + case LongType => + val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array() + val valueBuf = ByteBuffer.wrap(bytes) + valueBuf.order(ByteOrder.BIG_ENDIAN) + rowWriter.write(idx, valueBuf.getLong()) + + case FloatType => + val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array() + val valueBuf = ByteBuffer.wrap(bytes) + valueBuf.order(ByteOrder.BIG_ENDIAN) + if (marker == negativeValMarker) { + val floatVal = valueBuf.getFloat + val updatedVal = floatToRawIntBits(floatVal) ^ floatFlipBitMask + rowWriter.write(idx, intBitsToFloat(updatedVal)) + } else { + rowWriter.write(idx, valueBuf.getFloat()) + } + + case DoubleType => + val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array() + val valueBuf = ByteBuffer.wrap(bytes) + valueBuf.order(ByteOrder.BIG_ENDIAN) + if (marker == negativeValMarker) { + val doubleVal = valueBuf.getDouble + val updatedVal = doubleToRawLongBits(doubleVal) ^ doubleFlipBitMask + rowWriter.write(idx, longBitsToDouble(updatedVal)) + } else { + rowWriter.write(idx, valueBuf.getDouble()) + } + + case _ => throw new UnsupportedOperationException( + s"Range scan decoding not supported for data type: ${field.dataType}") + } + } + fieldIdx += 2 + } + + rowWriter.getRow() + } + + override def decodeValue(bytes: Array[Byte]): UnsafeRow = + decodeFromAvroToUnsafeRow( + bytes, avroEncoder.valueDeserializer, valueAvroType, valueProj) } abstract class RocksDBKeyStateEncoderBase( @@ -91,20 +721,21 @@ abstract class RocksDBKeyStateEncoderBase( object RocksDBStateEncoder { def getKeyEncoder( + stateEncoder: RocksDBDataEncoder, keyStateEncoderSpec: KeyStateEncoderSpec, useColumnFamilies: Boolean, virtualColFamilyId: Option[Short] = None): RocksDBKeyStateEncoder = { // Return the key state encoder based on the requested type keyStateEncoderSpec match { case NoPrefixKeyStateEncoderSpec(keySchema) => - new NoPrefixKeyStateEncoder(keySchema, useColumnFamilies, virtualColFamilyId) + new NoPrefixKeyStateEncoder(stateEncoder, keySchema, useColumnFamilies, virtualColFamilyId) case PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) => - new PrefixKeyScanStateEncoder(keySchema, numColsPrefixKey, + new PrefixKeyScanStateEncoder(stateEncoder, keySchema, numColsPrefixKey, useColumnFamilies, virtualColFamilyId) case RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) => - new RangeKeyScanStateEncoder(keySchema, orderingOrdinals, + new RangeKeyScanStateEncoder(stateEncoder, keySchema, orderingOrdinals, useColumnFamilies, virtualColFamilyId) case _ => @@ -114,12 +745,13 @@ object RocksDBStateEncoder { } def getValueEncoder( + stateEncoder: RocksDBDataEncoder, valueSchema: StructType, useMultipleValuesPerKey: Boolean): RocksDBValueStateEncoder = { if (useMultipleValuesPerKey) { - new MultiValuedStateEncoder(valueSchema) + new MultiValuedStateEncoder(stateEncoder, valueSchema) } else { - new SingleValueStateEncoder(valueSchema) + new SingleValueStateEncoder(stateEncoder, valueSchema) } } @@ -128,44 +760,6 @@ object RocksDBStateEncoder { Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET, virtualColFamilyId) encodedBytes } - - /** - * Encode the UnsafeRow of N bytes as a N+1 byte array. - * @note This creates a new byte array and memcopies the UnsafeRow to the new array. 
- */ - def encodeUnsafeRow(row: UnsafeRow): Array[Byte] = { - val bytesToEncode = row.getBytes - val encodedBytes = new Array[Byte](bytesToEncode.length + STATE_ENCODING_NUM_VERSION_BYTES) - Platform.putByte(encodedBytes, Platform.BYTE_ARRAY_OFFSET, STATE_ENCODING_VERSION) - // Platform.BYTE_ARRAY_OFFSET is the recommended way to memcopy b/w byte arrays. See Platform. - Platform.copyMemory( - bytesToEncode, Platform.BYTE_ARRAY_OFFSET, - encodedBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES, - bytesToEncode.length) - encodedBytes - } - - def decodeToUnsafeRow(bytes: Array[Byte], numFields: Int): UnsafeRow = { - if (bytes != null) { - val row = new UnsafeRow(numFields) - decodeToUnsafeRow(bytes, row) - } else { - null - } - } - - def decodeToUnsafeRow(bytes: Array[Byte], reusedRow: UnsafeRow): UnsafeRow = { - if (bytes != null) { - // Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st offset. See Platform. - reusedRow.pointTo( - bytes, - Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES, - bytes.length - STATE_ENCODING_NUM_VERSION_BYTES) - reusedRow - } else { - null - } - } } /** @@ -176,13 +770,12 @@ object RocksDBStateEncoder { * @param useColumnFamilies - if column family is enabled for this encoder */ class PrefixKeyScanStateEncoder( + stateEncoder: RocksDBDataEncoder, keySchema: StructType, numColsPrefixKey: Int, useColumnFamilies: Boolean = false, virtualColFamilyId: Option[Short] = None) - extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) { - - import RocksDBStateEncoder._ + extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) with Logging { private val prefixKeyFieldsWithIdx: Seq[(StructField, Int)] = { keySchema.zipWithIndex.take(numColsPrefixKey) @@ -210,8 +803,8 @@ class PrefixKeyScanStateEncoder( private val joinedRowOnKey = new JoinedRow() override def encodeKey(row: UnsafeRow): Array[Byte] = { - val prefixKeyEncoded = encodeUnsafeRow(extractPrefixKey(row)) - val remainingEncoded = encodeUnsafeRow(remainingKeyProjection(row)) + val prefixKeyEncoded = stateEncoder.encodeKey(extractPrefixKey(row)) + val remainingEncoded = stateEncoder.encodeRemainingKey(remainingKeyProjection(row)) val (encodedBytes, startingOffset) = encodeColumnFamilyPrefix( prefixKeyEncoded.length + remainingEncoded.length + 4 @@ -243,9 +836,9 @@ class PrefixKeyScanStateEncoder( Platform.copyMemory(keyBytes, decodeKeyStartOffset + 4 + prefixKeyEncodedLen, remainingKeyEncoded, Platform.BYTE_ARRAY_OFFSET, remainingKeyEncodedLen) - val prefixKeyDecoded = decodeToUnsafeRow(prefixKeyEncoded, numFields = numColsPrefixKey) - val remainingKeyDecoded = decodeToUnsafeRow(remainingKeyEncoded, - numFields = keySchema.length - numColsPrefixKey) + val prefixKeyDecoded = stateEncoder.decodeKey( + prefixKeyEncoded) + val remainingKeyDecoded = stateEncoder.decodeRemainingKey(remainingKeyEncoded) restoreKeyProjection(joinedRowOnKey.withLeft(prefixKeyDecoded).withRight(remainingKeyDecoded)) } @@ -255,7 +848,7 @@ class PrefixKeyScanStateEncoder( } override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = { - val prefixKeyEncoded = encodeUnsafeRow(prefixKey) + val prefixKeyEncoded = stateEncoder.encodeKey(prefixKey) val (prefix, startingOffset) = encodeColumnFamilyPrefix( prefixKeyEncoded.length + 4 ) @@ -301,14 +894,13 @@ class PrefixKeyScanStateEncoder( * @param useColumnFamilies - if column family is enabled for this encoder */ class RangeKeyScanStateEncoder( + stateEncoder: RocksDBDataEncoder, keySchema: StructType, 
orderingOrdinals: Seq[Int], useColumnFamilies: Boolean = false, virtualColFamilyId: Option[Short] = None) extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) { - import RocksDBStateEncoder._ - private val rangeScanKeyFieldsWithOrdinal: Seq[(StructField, Int)] = { orderingOrdinals.map { ordinal => val field = keySchema(ordinal) @@ -381,195 +973,13 @@ class RangeKeyScanStateEncoder( rangeScanKeyProjection(key) } - // bit masks used for checking sign or flipping all bits for negative float/double values - private val floatFlipBitMask = 0xFFFFFFFF - private val floatSignBitMask = 0x80000000 - - private val doubleFlipBitMask = 0xFFFFFFFFFFFFFFFFL - private val doubleSignBitMask = 0x8000000000000000L - - // Byte markers used to identify whether the value is null, negative or positive - // To ensure sorted ordering, we use the lowest byte value for negative numbers followed by - // positive numbers and then null values. - private val negativeValMarker: Byte = 0x00.toByte - private val positiveValMarker: Byte = 0x01.toByte - private val nullValMarker: Byte = 0x02.toByte - - // Rewrite the unsafe row by replacing fixed size fields with BIG_ENDIAN encoding - // using byte arrays. - // To handle "null" values, we prepend a byte to the byte array indicating whether the value - // is null or not. If the value is null, we write the null byte followed by zero bytes. - // If the value is not null, we write the null byte followed by the value. - // Note that setting null for the index on the unsafeRow is not feasible as it would change - // the sorting order on iteration. - // Also note that the same byte is used to indicate whether the value is negative or not. - private def encodePrefixKeyForRangeScan(row: UnsafeRow): UnsafeRow = { - val writer = new UnsafeRowWriter(orderingOrdinals.length) - writer.resetRowWriter() - rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) => - val field = fieldWithOrdinal._1 - val value = row.get(idx, field.dataType) - // Note that we cannot allocate a smaller buffer here even if the value is null - // because the effective byte array is considered variable size and needs to have - // the same size across all rows for the ordering to work as expected. 
- val bbuf = ByteBuffer.allocate(field.dataType.defaultSize + 1) - bbuf.order(ByteOrder.BIG_ENDIAN) - if (value == null) { - bbuf.put(nullValMarker) - writer.write(idx, bbuf.array()) - } else { - field.dataType match { - case BooleanType => - case ByteType => - val byteVal = value.asInstanceOf[Byte] - val signCol = if (byteVal < 0) { - negativeValMarker - } else { - positiveValMarker - } - bbuf.put(signCol) - bbuf.put(byteVal) - writer.write(idx, bbuf.array()) - - case ShortType => - val shortVal = value.asInstanceOf[Short] - val signCol = if (shortVal < 0) { - negativeValMarker - } else { - positiveValMarker - } - bbuf.put(signCol) - bbuf.putShort(shortVal) - writer.write(idx, bbuf.array()) - - case IntegerType => - val intVal = value.asInstanceOf[Int] - val signCol = if (intVal < 0) { - negativeValMarker - } else { - positiveValMarker - } - bbuf.put(signCol) - bbuf.putInt(intVal) - writer.write(idx, bbuf.array()) - - case LongType => - val longVal = value.asInstanceOf[Long] - val signCol = if (longVal < 0) { - negativeValMarker - } else { - positiveValMarker - } - bbuf.put(signCol) - bbuf.putLong(longVal) - writer.write(idx, bbuf.array()) - - case FloatType => - val floatVal = value.asInstanceOf[Float] - val rawBits = floatToRawIntBits(floatVal) - // perform sign comparison using bit manipulation to ensure NaN values are handled - // correctly - if ((rawBits & floatSignBitMask) != 0) { - // for negative values, we need to flip all the bits to ensure correct ordering - val updatedVal = rawBits ^ floatFlipBitMask - bbuf.put(negativeValMarker) - // convert the bits back to float - bbuf.putFloat(intBitsToFloat(updatedVal)) - } else { - bbuf.put(positiveValMarker) - bbuf.putFloat(floatVal) - } - writer.write(idx, bbuf.array()) - - case DoubleType => - val doubleVal = value.asInstanceOf[Double] - val rawBits = doubleToRawLongBits(doubleVal) - // perform sign comparison using bit manipulation to ensure NaN values are handled - // correctly - if ((rawBits & doubleSignBitMask) != 0) { - // for negative values, we need to flip all the bits to ensure correct ordering - val updatedVal = rawBits ^ doubleFlipBitMask - bbuf.put(negativeValMarker) - // convert the bits back to double - bbuf.putDouble(longBitsToDouble(updatedVal)) - } else { - bbuf.put(positiveValMarker) - bbuf.putDouble(doubleVal) - } - writer.write(idx, bbuf.array()) - } - } - } - writer.getRow() - } - - // Rewrite the unsafe row by converting back from BIG_ENDIAN byte arrays to the - // original data types. - // For decode, we extract the byte array from the UnsafeRow, and then read the first byte - // to determine if the value is null or not. If the value is null, we set the ordinal on - // the UnsafeRow to null. If the value is not null, we read the rest of the bytes to get the - // actual value. - // For negative float/double values, we need to flip all the bits back to get the original value. 
- private def decodePrefixKeyForRangeScan(row: UnsafeRow): UnsafeRow = { - val writer = new UnsafeRowWriter(orderingOrdinals.length) - writer.resetRowWriter() - rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) => - val field = fieldWithOrdinal._1 - - val value = row.getBinary(idx) - val bbuf = ByteBuffer.wrap(value.asInstanceOf[Array[Byte]]) - bbuf.order(ByteOrder.BIG_ENDIAN) - val isNullOrSignCol = bbuf.get() - if (isNullOrSignCol == nullValMarker) { - // set the column to null and skip reading the next byte(s) - writer.setNullAt(idx) - } else { - field.dataType match { - case BooleanType => - case ByteType => - writer.write(idx, bbuf.get) - - case ShortType => - writer.write(idx, bbuf.getShort) - - case IntegerType => - writer.write(idx, bbuf.getInt) - - case LongType => - writer.write(idx, bbuf.getLong) - - case FloatType => - if (isNullOrSignCol == negativeValMarker) { - // if the number is negative, get the raw binary bits for the float - // and flip the bits back - val updatedVal = floatToRawIntBits(bbuf.getFloat) ^ floatFlipBitMask - writer.write(idx, intBitsToFloat(updatedVal)) - } else { - writer.write(idx, bbuf.getFloat) - } - - case DoubleType => - if (isNullOrSignCol == negativeValMarker) { - // if the number is negative, get the raw binary bits for the double - // and flip the bits back - val updatedVal = doubleToRawLongBits(bbuf.getDouble) ^ doubleFlipBitMask - writer.write(idx, longBitsToDouble(updatedVal)) - } else { - writer.write(idx, bbuf.getDouble) - } - } - } - } - writer.getRow() - } - override def encodeKey(row: UnsafeRow): Array[Byte] = { // This prefix key has the columns specified by orderingOrdinals val prefixKey = extractPrefixKey(row) - val rangeScanKeyEncoded = encodeUnsafeRow(encodePrefixKeyForRangeScan(prefixKey)) + val rangeScanKeyEncoded = stateEncoder.encodePrefixKeyForRangeScan(prefixKey) val result = if (orderingOrdinals.length < keySchema.length) { - val remainingEncoded = encodeUnsafeRow(remainingKeyProjection(row)) + val remainingEncoded = stateEncoder.encodeRemainingKey(remainingKeyProjection(row)) val (encodedBytes, startingOffset) = encodeColumnFamilyPrefix( rangeScanKeyEncoded.length + remainingEncoded.length + 4 ) @@ -606,9 +1016,8 @@ class RangeKeyScanStateEncoder( Platform.copyMemory(keyBytes, decodeKeyStartOffset + 4, prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET, prefixKeyEncodedLen) - val prefixKeyDecodedForRangeScan = decodeToUnsafeRow(prefixKeyEncoded, - numFields = orderingOrdinals.length) - val prefixKeyDecoded = decodePrefixKeyForRangeScan(prefixKeyDecodedForRangeScan) + val prefixKeyDecoded = stateEncoder.decodePrefixKeyForRangeScan( + prefixKeyEncoded) if (orderingOrdinals.length < keySchema.length) { // Here we calculate the remainingKeyEncodedLen leveraging the length of keyBytes @@ -620,8 +1029,7 @@ class RangeKeyScanStateEncoder( remainingKeyEncoded, Platform.BYTE_ARRAY_OFFSET, remainingKeyEncodedLen) - val remainingKeyDecoded = decodeToUnsafeRow(remainingKeyEncoded, - numFields = keySchema.length - orderingOrdinals.length) + val remainingKeyDecoded = stateEncoder.decodeRemainingKey(remainingKeyEncoded) val joined = joinedRowOnKey.withLeft(prefixKeyDecoded).withRight(remainingKeyDecoded) val restored = restoreKeyProjection(joined) @@ -634,7 +1042,7 @@ class RangeKeyScanStateEncoder( } override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = { - val rangeScanKeyEncoded = encodeUnsafeRow(encodePrefixKeyForRangeScan(prefixKey)) + val rangeScanKeyEncoded = 
stateEncoder.encodePrefixKeyForRangeScan(prefixKey) val (prefix, startingOffset) = encodeColumnFamilyPrefix(rangeScanKeyEncoded.length + 4) Platform.putInt(prefix, startingOffset, rangeScanKeyEncoded.length) @@ -659,21 +1067,20 @@ class RangeKeyScanStateEncoder( * then the generated array byte will be N+1 bytes. */ class NoPrefixKeyStateEncoder( + stateEncoder: RocksDBDataEncoder, keySchema: StructType, useColumnFamilies: Boolean = false, virtualColFamilyId: Option[Short] = None) extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) { - import RocksDBStateEncoder._ - // Reusable objects private val keyRow = new UnsafeRow(keySchema.size) override def encodeKey(row: UnsafeRow): Array[Byte] = { if (!useColumnFamilies) { - encodeUnsafeRow(row) + stateEncoder.encodeUnsafeRow(row) } else { - val bytesToEncode = row.getBytes + val bytesToEncode = stateEncoder.encodeKey(row) val (encodedBytes, startingOffset) = encodeColumnFamilyPrefix( bytesToEncode.length + STATE_ENCODING_NUM_VERSION_BYTES @@ -697,15 +1104,13 @@ class NoPrefixKeyStateEncoder( if (useColumnFamilies) { if (keyBytes != null) { // Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st offset. See Platform. - keyRow.pointTo( - keyBytes, - decodeKeyStartOffset + STATE_ENCODING_NUM_VERSION_BYTES, - keyBytes.length - STATE_ENCODING_NUM_VERSION_BYTES - VIRTUAL_COL_FAMILY_PREFIX_BYTES) - keyRow + stateEncoder.decodeKey(keyBytes) } else { null } - } else decodeToUnsafeRow(keyBytes, keyRow) + } else { + stateEncoder.decodeToUnsafeRow(keyBytes, keyRow) + } } override def supportPrefixKeyScan: Boolean = false @@ -728,16 +1133,13 @@ class NoPrefixKeyStateEncoder( * merged in RocksDB using merge operation, and all merged values can be read using decodeValues * operation. */ -class MultiValuedStateEncoder(valueSchema: StructType) +class MultiValuedStateEncoder( + stateEncoder: RocksDBDataEncoder, + valueSchema: StructType) extends RocksDBValueStateEncoder with Logging { - import RocksDBStateEncoder._ - - // Reusable objects - private val valueRow = new UnsafeRow(valueSchema.size) - override def encodeValue(row: UnsafeRow): Array[Byte] = { - val bytes = encodeUnsafeRow(row) + val bytes = stateEncoder.encodeValue(row) val numBytes = bytes.length val encodedBytes = new Array[Byte](java.lang.Integer.BYTES + bytes.length) @@ -756,7 +1158,7 @@ class MultiValuedStateEncoder(valueSchema: StructType) val encodedValue = new Array[Byte](numBytes) Platform.copyMemory(valueBytes, java.lang.Integer.BYTES + Platform.BYTE_ARRAY_OFFSET, encodedValue, Platform.BYTE_ARRAY_OFFSET, numBytes) - decodeToUnsafeRow(encodedValue, valueRow) + stateEncoder.decodeValue(encodedValue) } } @@ -782,7 +1184,7 @@ class MultiValuedStateEncoder(valueSchema: StructType) pos += numBytes pos += 1 // eat the delimiter character - decodeToUnsafeRow(encodedValue, valueRow) + stateEncoder.decodeValue(encodedValue) } } } @@ -803,15 +1205,12 @@ class MultiValuedStateEncoder(valueSchema: StructType) * (offset 0 is the version byte of value 0). That is, if the unsafe row has N bytes, * then the generated array byte will be N+1 bytes. 
*/ -class SingleValueStateEncoder(valueSchema: StructType) +class SingleValueStateEncoder( + stateEncoder: RocksDBDataEncoder, + valueSchema: StructType) extends RocksDBValueStateEncoder { - import RocksDBStateEncoder._ - - // Reusable objects - private val valueRow = new UnsafeRow(valueSchema.size) - - override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row) + override def encodeValue(row: UnsafeRow): Array[Byte] = stateEncoder.encodeValue(row) /** * Decode byte array for a value to a UnsafeRow. @@ -820,7 +1219,7 @@ class SingleValueStateEncoder(valueSchema: StructType) * the given byte array. */ override def decodeValue(valueBytes: Array[Byte]): UnsafeRow = { - decodeToUnsafeRow(valueBytes, valueRow) + stateEncoder.decodeValue(valueBytes) } override def supportsMultipleValuesPerKey: Boolean = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 1fc6ab5910c6c..e6507ba50a31e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -18,10 +18,12 @@ package org.apache.spark.sql.execution.streaming.state import java.io._ -import java.util.concurrent.ConcurrentHashMap +import java.util.UUID +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import scala.util.control.NonFatal +import com.google.common.cache.CacheBuilder import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -29,11 +31,12 @@ import org.apache.spark.{SparkConf, SparkEnv, SparkException} import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys._ import org.apache.spark.io.CompressionCodec +import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions, AvroSerializer, SchemaConverters} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.streaming.CheckpointFileManager +import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, StreamExecution} import org.apache.spark.sql.types.StructType -import org.apache.spark.util.Utils +import org.apache.spark.util.{NonFateSharingCache, Utils} private[sql] class RocksDBStateStoreProvider extends StateStoreProvider with Logging with Closeable @@ -74,9 +77,15 @@ private[sql] class RocksDBStateStoreProvider isInternal: Boolean = false): Unit = { verifyColFamilyCreationOrDeletion("create_col_family", colFamilyName, isInternal) val newColFamilyId = rocksDB.createColFamilyIfAbsent(colFamilyName) + val dataEncoderCacheKey = s"${getRunId(hadoopConf)}_${stateStoreId.operatorId}_" + + s"${stateStoreId.partitionId}_${colFamilyName}" + + val dataEncoder = getDataEncoder( + stateStoreEncoding, dataEncoderCacheKey, keyStateEncoderSpec, valueSchema) + keyValueEncoderMap.putIfAbsent(colFamilyName, - (RocksDBStateEncoder.getKeyEncoder(keyStateEncoderSpec, useColumnFamilies, - Some(newColFamilyId)), RocksDBStateEncoder.getValueEncoder(valueSchema, + (RocksDBStateEncoder.getKeyEncoder(dataEncoder, keyStateEncoderSpec, useColumnFamilies, + Some(newColFamilyId)), RocksDBStateEncoder.getValueEncoder(dataEncoder, valueSchema, useMultipleValuesPerKey))) } @@ -364,6 +373,7 @@ private[sql] class RocksDBStateStoreProvider this.storeConf = storeConf this.hadoopConf = hadoopConf 
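+    // Remember whether column families are in use and which encoding format ("unsaferow" or
+    // "avro") was configured, since both determine how the RocksDBDataEncoder is created below.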
this.useColumnFamilies = useColumnFamilies + this.stateStoreEncoding = storeConf.stateStoreEncodingFormat if (useMultipleValuesPerKey) { require(useColumnFamilies, "Multiple values per key support requires column families to be" + @@ -377,10 +387,16 @@ private[sql] class RocksDBStateStoreProvider defaultColFamilyId = Some(rocksDB.createColFamilyIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME)) } + val dataEncoderCacheKey = s"${getRunId(hadoopConf)}_${stateStoreId.operatorId}_" + + s"${stateStoreId.partitionId}_${StateStore.DEFAULT_COL_FAMILY_NAME}" + + val dataEncoder = getDataEncoder( + stateStoreEncoding, dataEncoderCacheKey, keyStateEncoderSpec, valueSchema) + keyValueEncoderMap.putIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME, - (RocksDBStateEncoder.getKeyEncoder(keyStateEncoderSpec, + (RocksDBStateEncoder.getKeyEncoder(dataEncoder, keyStateEncoderSpec, useColumnFamilies, defaultColFamilyId), - RocksDBStateEncoder.getValueEncoder(valueSchema, useMultipleValuesPerKey))) + RocksDBStateEncoder.getValueEncoder(dataEncoder, valueSchema, useMultipleValuesPerKey))) } override def stateStoreId: StateStoreId = stateStoreId_ @@ -458,6 +474,7 @@ private[sql] class RocksDBStateStoreProvider @volatile private var storeConf: StateStoreConf = _ @volatile private var hadoopConf: Configuration = _ @volatile private var useColumnFamilies: Boolean = _ + @volatile private var stateStoreEncoding: String = _ private[sql] lazy val rocksDB = { val dfsRootDir = stateStoreId.storeCheckpointLocation().toString @@ -593,6 +610,113 @@ object RocksDBStateStoreProvider { val STATE_ENCODING_VERSION: Byte = 0 val VIRTUAL_COL_FAMILY_PREFIX_BYTES = 2 + private val MAX_AVRO_ENCODERS_IN_CACHE = 1000 + // Add the cache at companion object level so it persists across provider instances + private val dataEncoderCache: NonFateSharingCache[String, RocksDBDataEncoder] = { + val guavaCache = CacheBuilder.newBuilder() + .maximumSize(MAX_AVRO_ENCODERS_IN_CACHE) // Adjust size based on your needs + .expireAfterAccess(1, TimeUnit.HOURS) // Optional: Add expiration if needed + .build[String, RocksDBDataEncoder]() + + new NonFateSharingCache(guavaCache) + } + + def getDataEncoder( + stateStoreEncoding: String, + encoderCacheKey: String, + keyStateEncoderSpec: KeyStateEncoderSpec, + valueSchema: StructType): RocksDBDataEncoder = { + + stateStoreEncoding match { + case "avro" => + RocksDBStateStoreProvider.dataEncoderCache.get( + encoderCacheKey, + new java.util.concurrent.Callable[AvroStateEncoder] { + override def call(): AvroStateEncoder = { + val avroEncoder = createAvroEnc(keyStateEncoderSpec, valueSchema) + new AvroStateEncoder(keyStateEncoderSpec, valueSchema, avroEncoder) + } + } + ) + case "unsaferow" => + RocksDBStateStoreProvider.dataEncoderCache.get( + encoderCacheKey, + new java.util.concurrent.Callable[UnsafeRowStateEncoder] { + override def call(): UnsafeRowStateEncoder = { + new UnsafeRowStateEncoder(keyStateEncoderSpec, valueSchema) + } + } + ) + } + } + + private def getRunId(hadoopConf: Configuration): String = { + val runId = hadoopConf.get(StreamExecution.RUN_ID_KEY) + if (runId != null) { + runId + } else { + assert(Utils.isTesting, "Failed to find query id/batch Id in task context") + UUID.randomUUID().toString + } + } + + private def getAvroSerializer(schema: StructType): AvroSerializer = { + val avroType = SchemaConverters.toAvroType(schema) + new AvroSerializer(schema, avroType, nullable = false) + } + + private def getAvroDeserializer(schema: StructType): AvroDeserializer = { + val avroType = 
SchemaConverters.toAvroType(schema) + val avroOptions = AvroOptions(Map.empty) + new AvroDeserializer(avroType, schema, + avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType, + avroOptions.stableIdPrefixForUnionType, avroOptions.recursiveFieldMaxDepth) + } + + private def createAvroEnc( + keyStateEncoderSpec: KeyStateEncoderSpec, + valueSchema: StructType + ): AvroEncoder = { + val valueSerializer = getAvroSerializer(valueSchema) + val valueDeserializer = getAvroDeserializer(valueSchema) + + // Get key schema based on encoder spec type + val keySchema = keyStateEncoderSpec match { + case NoPrefixKeyStateEncoderSpec(schema) => + schema + case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) => + StructType(schema.take(numColsPrefixKey)) + case RangeKeyScanStateEncoderSpec(schema, orderingOrdinals) => + val remainingSchema = { + 0.until(schema.length).diff(orderingOrdinals).map { ordinal => + schema(ordinal) + } + } + StructType(remainingSchema) + } + + // Handle suffix key schema for prefix scan case + val suffixKeySchema = keyStateEncoderSpec match { + case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) => + Some(StructType(schema.drop(numColsPrefixKey))) + case _ => + None + } + + val keySerializer = getAvroSerializer(keySchema) + val keyDeserializer = getAvroDeserializer(keySchema) + + // Create the AvroEncoder with all components + AvroEncoder( + keySerializer, + keyDeserializer, + valueSerializer, + valueDeserializer, + suffixKeySchema.map(getAvroSerializer), + suffixKeySchema.map(getAvroDeserializer) + ) + } + // Native operation latencies report as latency in microseconds // as SQLMetrics support millis. Convert the value to millis val CUSTOM_METRIC_GET_TIME = StateStoreCustomTimingMetric( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala index 721d72b6a0991..69a29cdbe7a17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala @@ -24,6 +24,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.internal.{Logging, LogKeys, MDC} +import org.apache.spark.sql.avro.{AvroDeserializer, AvroSerializer} import org.apache.spark.sql.catalyst.util.UnsafeRowUtils import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, StatefulOperatorStateInfo} import org.apache.spark.sql.execution.streaming.state.SchemaHelper.{SchemaReader, SchemaWriter} @@ -37,6 +38,17 @@ case class StateSchemaValidationResult( schemaPath: String ) +// Avro encoder that is used by the RocksDBStateStoreProvider and RocksDBStateEncoder +// in order to serialize from UnsafeRow to a byte array of Avro encoding. 
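+// The suffix key serializer/deserializer pair is only populated for prefix key scan encoders,
+// where the remaining (non-prefix) key columns are serialized separately from the prefix key.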
+case class AvroEncoder( + keySerializer: AvroSerializer, + keyDeserializer: AvroDeserializer, + valueSerializer: AvroSerializer, + valueDeserializer: AvroDeserializer, + suffixKeySerializer: Option[AvroSerializer] = None, + suffixKeyDeserializer: Option[AvroDeserializer] = None +) extends Serializable + // Used to represent the schema of a column family in the state store case class StateStoreColFamilySchema( colFamilyName: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 72bc3ca33054d..85cb4f65f1f04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -37,10 +37,22 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.util.UnsafeRowUtils import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo +import org.apache.spark.sql.execution.streaming.{StatefulOperatorStateInfo, StreamExecution} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{NextIterator, ThreadUtils, Utils} +sealed trait StateStoreEncoding { + override def toString: String = this match { + case StateStoreEncoding.UnsafeRow => "unsaferow" + case StateStoreEncoding.Avro => "avro" + } +} + +object StateStoreEncoding { + case object UnsafeRow extends StateStoreEncoding + case object Avro extends StateStoreEncoding +} + /** * Base trait for a versioned key-value store which provides read operations. Each instance of a * `ReadStateStore` represents a specific version of state data, and such instances are created @@ -310,6 +322,7 @@ case class StateStoreCustomTimingMetric(name: String, desc: String) extends Stat } sealed trait KeyStateEncoderSpec { + def keySchema: StructType def jsonValue: JValue def json: String = compact(render(jsonValue)) } @@ -746,6 +759,7 @@ object StateStore extends Logging { storeConf: StateStoreConf, hadoopConf: Configuration, useMultipleValuesPerKey: Boolean = false): ReadStateStore = { + hadoopConf.set(StreamExecution.RUN_ID_KEY, storeProviderId.queryRunId.toString) if (version < 0) { throw QueryExecutionErrors.unexpectedStateStoreVersion(version) } @@ -766,6 +780,7 @@ object StateStore extends Logging { storeConf: StateStoreConf, hadoopConf: Configuration, useMultipleValuesPerKey: Boolean = false): StateStore = { + hadoopConf.set(StreamExecution.RUN_ID_KEY, storeProviderId.queryRunId.toString) if (version < 0) { throw QueryExecutionErrors.unexpectedStateStoreVersion(version) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index c8af395e996d8..8b373a5f658bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -83,6 +83,10 @@ class StateStoreConf( /** The interval of maintenance tasks. */ val maintenanceInterval = sqlConf.streamingMaintenanceInterval + + /** The encoding format used by stateful operators to store information in the state store.
*/ + val stateStoreEncodingFormat = sqlConf.stateStoreEncodingFormat + /** * When creating new state store checkpoint, which format version to use. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala index baab6327b35c1..12a8ac44b1ab2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.io.CompressionCodec import org.apache.spark.sql.{Encoders, Row} import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBFileManager, RocksDBStateStoreProvider, TestClass} +import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, AlsoTestWithEncodingTypes, RocksDBFileManager, RocksDBStateStoreProvider, TestClass} import org.apache.spark.sql.functions.{col, explode, timestamp_seconds} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.{InputMapRow, ListState, MapInputEvent, MapOutputEvent, MapStateTTLProcessor, MaxEventTimeStatefulProcessor, OutputMode, RunningCountStatefulProcessor, RunningCountStatefulProcessorWithProcTimeTimerUpdates, StatefulProcessor, StateStoreMetricsTest, TestMapStateProcessor, TimeMode, TimerValues, TransformWithStateSuiteUtils, Trigger, TTLConfig, ValueState} @@ -126,6 +126,7 @@ class SessionGroupsStatefulProcessorWithTTL extends * Test suite to verify integration of state data source reader with the transformWithState operator */ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest + with AlsoTestWithEncodingTypes with AlsoTestWithChangelogCheckpointingEnabled { import testImplicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index e1bd9dd38066b..0abdcadefbd55 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -43,6 +43,7 @@ import org.apache.spark.util.Utils @ExtendedSQLTest class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvider] with AlsoTestWithChangelogCheckpointingEnabled + with AlsoTestWithEncodingTypes with SharedSparkSession with BeforeAndAfter { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala index 637eb49130305..61ca8e7c32f61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala @@ -86,6 +86,19 @@ trait RocksDBStateStoreChangelogCheckpointingTestUtil { } } +trait AlsoTestWithEncodingTypes extends SQLTestUtils { + override protected def test(testName: String, testTags: Tag*)(testBody: => Any) + (implicit pos: Position): Unit = { 
+ Seq("unsaferow", "avro").foreach { encoding => + super.test(s"$testName (encoding = $encoding)", testTags: _*) { + withSQLConf(SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key -> encoding) { + testBody + } + } + } + } +} + trait AlsoTestWithChangelogCheckpointingEnabled extends SQLTestUtils with RocksDBStateStoreChangelogCheckpointingTestUtil { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala index 55d08cd8f12a7..8984d9b0845b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala @@ -423,7 +423,7 @@ class ValueStateSuite extends StateVariableSuiteBase { * types (ValueState, ListState, MapState) used in arbitrary stateful operators. */ abstract class StateVariableSuiteBase extends SharedSparkSession - with BeforeAndAfter { + with BeforeAndAfter with AlsoTestWithEncodingTypes { before { StateStore.stop() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala index 88862e2ad0791..5d88db0d01ba4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming import org.apache.spark.SparkIllegalArgumentException import org.apache.spark.sql.Encoders import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider} +import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, AlsoTestWithEncodingTypes, RocksDBStateStoreProvider} import org.apache.spark.sql.internal.SQLConf case class InputRow(key: String, action: String, value: String) @@ -127,7 +127,8 @@ class ToggleSaveAndEmitProcessor } class TransformWithListStateSuite extends StreamTest - with AlsoTestWithChangelogCheckpointingEnabled { + with AlsoTestWithChangelogCheckpointingEnabled + with AlsoTestWithEncodingTypes { import testImplicits._ test("test appending null value in list state throw exception") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala index 76c5cbeee424b..253fd13395d64 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming import org.apache.spark.SparkIllegalArgumentException import org.apache.spark.sql.Encoders import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider} +import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, AlsoTestWithEncodingTypes, RocksDBStateStoreProvider} import org.apache.spark.sql.internal.SQLConf case class InputMapRow(key: String, action: String, value: (String, String)) @@ -81,7 +81,8 @@ class TestMapStateProcessor * 
operators such as transformWithState. */ class TransformWithMapStateSuite extends StreamTest - with AlsoTestWithChangelogCheckpointingEnabled { + with AlsoTestWithChangelogCheckpointingEnabled + with AlsoTestWithEncodingTypes { import testImplicits._ private def testMapStateWithNullUserKey(inputMapRow: InputMapRow): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateChainingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateChainingSuite.scala index 6888fcba45f3e..634f9b543905d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateChainingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateChainingSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.{SparkRuntimeException, SparkThrowable} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.ExtendedAnalysisException import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamExecution} -import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider +import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithEncodingTypes, RocksDBStateStoreProvider} import org.apache.spark.sql.functions.window import org.apache.spark.sql.internal.SQLConf @@ -103,7 +103,8 @@ case class AggEventRow( window: Window, count: Long) -class TransformWithStateChainingSuite extends StreamTest { +class TransformWithStateChainingSuite extends StreamTest + with AlsoTestWithEncodingTypes { import testImplicits._ test("watermark is propagated correctly for next stateful operator" + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 505775d4f6a9b..1f0ea2be848d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -429,7 +429,8 @@ class SleepingTimerProcessor extends StatefulProcessor[String, String, String] { * Class that adds tests for transformWithState stateful streaming operator */ class TransformWithStateSuite extends StateStoreMetricsTest - with AlsoTestWithChangelogCheckpointingEnabled { + with AlsoTestWithChangelogCheckpointingEnabled + with AlsoTestWithEncodingTypes { import testImplicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala index 2ddf69aa49e04..9a4618f8922f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala @@ -21,7 +21,7 @@ import java.sql.Timestamp import java.time.Duration import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider +import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithEncodingTypes, RocksDBStateStoreProvider} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock @@ -41,7 +41,8 @@ case class OutputEvent( * Test suite base for TransformWithState with TTL support. 
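The AlsoTestWithEncodingTypes mixin added to the suites above and below follows a simple "run the same body once per config value" pattern. A self-contained sketch of that pattern (the names here are hypothetical stand-ins; only the two encoding values come from this patch) is:

object EncodingMatrixSketch {
  // Stand-in for withSQLConf: pin a setting, run the body, then restore the previous value.
  def withSetting(key: String, value: String)(body: => Unit): Unit = {
    val previous = sys.props.get(key)
    sys.props(key) = value
    try body finally {
      previous match {
        case Some(v) => sys.props(key) = v
        case None => sys.props -= key
      }
    }
  }

  def main(args: Array[String]): Unit = {
    Seq("unsaferow", "avro").foreach { encoding =>
      withSetting("sketch.stateStore.encodingFormat", encoding) {
        // A real suite would run the original test body here under the pinned encoding.
        assert(sys.props("sketch.stateStore.encodingFormat") == encoding)
      }
    }
  }
}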
*/ abstract class TransformWithStateTTLTest - extends StreamTest { + extends StreamTest + with AlsoTestWithEncodingTypes { import testImplicits._ def getProcessor(ttlConfig: TTLConfig): StatefulProcessor[String, InputEvent, OutputEvent] From 35414a887d3582618f6df3e257dfbe499b72ba3a Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Sat, 23 Nov 2024 14:24:23 -0800 Subject: [PATCH 02/17] fixed some stuff --- .../streaming/state/RocksDBStateEncoder.scala | 30 +++++++++++++++---- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index d718ec54e6ecb..3d4ee7bbb60af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -254,7 +254,15 @@ class UnsafeRowStateEncoder( } } - override def decodeRemainingKey(bytes: Array[Byte]): UnsafeRow = null + override def decodeRemainingKey(bytes: Array[Byte]): UnsafeRow = { + keyStateEncoderSpec match { + case PrefixKeyScanStateEncoderSpec(_, numColsPrefixKey) => + decodeToUnsafeRow(bytes, numFields = numColsPrefixKey) + case RangeKeyScanStateEncoderSpec(_, orderingOrdinals) => + decodeToUnsafeRow(bytes, keySchema.length - orderingOrdinals.length) + case _ => null + } + } override def decodePrefixKeyForRangeScan(bytes: Array[Byte]): UnsafeRow = { assert(keyStateEncoderSpec.isInstanceOf[RangeKeyScanStateEncoderSpec]) @@ -405,7 +413,7 @@ class AvroStateEncoder( if (valueBytes != null) { val reader = new GenericDatumReader[Any](valueAvroType) val decoder = DecoderFactory.get().binaryDecoder( - valueBytes, Platform.BYTE_ARRAY_OFFSET, valueBytes.length, null) + valueBytes, 0, valueBytes.length, null) // bytes -> Avro.GenericDataRecord val genericData = reader.read(null, decoder) // Avro.GenericDataRecord -> InternalRow @@ -567,7 +575,8 @@ class AvroStateEncoder( case NoPrefixKeyStateEncoderSpec(_) => decodeFromAvroToUnsafeRow(bytes, avroEncoder.keyDeserializer, keyAvroType, keyProj) case PrefixKeyScanStateEncoderSpec(_, _) => - decodeFromAvroToUnsafeRow(bytes, avroEncoder.keyDeserializer, keyAvroType, prefixKeyProj) + decodeFromAvroToUnsafeRow( + bytes, avroEncoder.keyDeserializer, prefixKeyAvroType, prefixKeyProj) case _ => null } } @@ -899,7 +908,7 @@ class RangeKeyScanStateEncoder( orderingOrdinals: Seq[Int], useColumnFamilies: Boolean = false, virtualColFamilyId: Option[Short] = None) - extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) { + extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) with Logging { private val rangeScanKeyFieldsWithOrdinal: Seq[(StructField, Int)] = { orderingOrdinals.map { ordinal => @@ -1103,8 +1112,17 @@ class NoPrefixKeyStateEncoder( override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = { if (useColumnFamilies) { if (keyBytes != null) { - // Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st offset. See Platform. 
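The decodeKey change that follows strips the fixed-size header (the virtual column family prefix plus the version byte) before handing the remaining key bytes to the data encoder. A minimal sketch of that slice-out step, with hypothetical constants standing in for VIRTUAL_COL_FAMILY_PREFIX_BYTES and STATE_ENCODING_NUM_VERSION_BYTES:

object StripKeyHeaderSketch {
  val colFamilyPrefixBytes = 2 // stand-in for VIRTUAL_COL_FAMILY_PREFIX_BYTES
  val versionBytes = 1         // stand-in for STATE_ENCODING_NUM_VERSION_BYTES

  // Drop the leading header so only the encoder-produced key bytes remain.
  def stripHeader(keyBytes: Array[Byte]): Array[Byte] = {
    val headerLength = colFamilyPrefixBytes + versionBytes
    val dataBytes = new Array[Byte](keyBytes.length - headerLength)
    System.arraycopy(keyBytes, headerLength, dataBytes, 0, dataBytes.length)
    dataBytes
  }

  def main(args: Array[String]): Unit = {
    val raw = Array[Byte](0, 7 /* column family id */, 0 /* version */, 1, 2, 3)
    assert(stripHeader(raw).sameElements(Array[Byte](1, 2, 3)))
  }
}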
- stateEncoder.decodeKey(keyBytes) + // Create new byte array without prefix + val dataLength = keyBytes.length - + STATE_ENCODING_NUM_VERSION_BYTES - VIRTUAL_COL_FAMILY_PREFIX_BYTES + val dataBytes = new Array[Byte](dataLength) + Platform.copyMemory( + keyBytes, + decodeKeyStartOffset + STATE_ENCODING_NUM_VERSION_BYTES, + dataBytes, + Platform.BYTE_ARRAY_OFFSET, + dataLength) + stateEncoder.decodeKey(dataBytes) } else { null } From 19f34ab51548675fc25e0f984a5efcaac88af133 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Sat, 23 Nov 2024 14:40:40 -0800 Subject: [PATCH 03/17] renaming --- .../streaming/state/RocksDBStateEncoder.scala | 10 +++++----- .../streaming/state/RocksDBStateStoreProvider.scala | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index 3d4ee7bbb60af..ea9590a8921dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -51,7 +51,7 @@ sealed trait RocksDBValueStateEncoder { def decodeValues(valueBytes: Array[Byte]): Iterator[UnsafeRow] } -trait StateEncoder { +trait DataEncoder { def encodeKey(row: UnsafeRow): Array[Byte] def encodeRemainingKey(row: UnsafeRow): Array[Byte] def encodePrefixKeyForRangeScan(row: UnsafeRow): Array[Byte] @@ -65,7 +65,7 @@ trait StateEncoder { abstract class RocksDBDataEncoder( keyStateEncoderSpec: KeyStateEncoderSpec, - valueSchema: StructType) extends StateEncoder { + valueSchema: StructType) extends DataEncoder { val keySchema = keyStateEncoderSpec.keySchema val reusedKeyRow = new UnsafeRow(keyStateEncoderSpec.keySchema.length) @@ -124,7 +124,7 @@ abstract class RocksDBDataEncoder( } } -class UnsafeRowStateEncoder( +class UnsafeRowDataEncoder( keyStateEncoderSpec: KeyStateEncoderSpec, valueSchema: StructType) extends RocksDBDataEncoder(keyStateEncoderSpec, valueSchema) { @@ -1087,7 +1087,7 @@ class NoPrefixKeyStateEncoder( override def encodeKey(row: UnsafeRow): Array[Byte] = { if (!useColumnFamilies) { - stateEncoder.encodeUnsafeRow(row) + stateEncoder.encodeKey(row) } else { val bytesToEncode = stateEncoder.encodeKey(row) val (encodedBytes, startingOffset) = encodeColumnFamilyPrefix( @@ -1127,7 +1127,7 @@ class NoPrefixKeyStateEncoder( null } } else { - stateEncoder.decodeToUnsafeRow(keyBytes, keyRow) + stateEncoder.decodeKey(keyBytes) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index e6507ba50a31e..791727d5210d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -641,9 +641,9 @@ object RocksDBStateStoreProvider { case "unsaferow" => RocksDBStateStoreProvider.dataEncoderCache.get( encoderCacheKey, - new java.util.concurrent.Callable[UnsafeRowStateEncoder] { - override def call(): UnsafeRowStateEncoder = { - new UnsafeRowStateEncoder(keyStateEncoderSpec, valueSchema) + new java.util.concurrent.Callable[UnsafeRowDataEncoder] { + override def call(): UnsafeRowDataEncoder = { + new 
UnsafeRowDataEncoder(keyStateEncoderSpec, valueSchema) } } ) From 94acc5d9398b1b0e3e48132e03f69db068a77072 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Sat, 23 Nov 2024 15:14:03 -0800 Subject: [PATCH 04/17] scaladoc --- .../streaming/state/RocksDBStateEncoder.scala | 124 +++++++++++++----- .../state/RocksDBStateStoreProvider.scala | 37 +++++- 2 files changed, 123 insertions(+), 38 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index ea9590a8921dc..b7eb95644c405 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -51,6 +51,15 @@ sealed trait RocksDBValueStateEncoder { def decodeValues(valueBytes: Array[Byte]): Iterator[UnsafeRow] } +/** + * The DataEncoder can encode UnsafeRows into raw bytes in two ways: + * - Using the direct byte layout of the UnsafeRow + * - Converting the UnsafeRow into an Avro row, and encoding that + * In both of these cases, the raw bytes that are written into RockDB have + * headers, footers and other metadata, but they also have data that is provided + * by the callers. The metadata in each row does not need to be written as Avro or UnsafeRow, + * but the actual data provided by the caller does. + */ trait DataEncoder { def encodeKey(row: UnsafeRow): Array[Byte] def encodeRemainingKey(row: UnsafeRow): Array[Byte] @@ -333,11 +342,12 @@ class AvroStateEncoder( avroEncoder: AvroEncoder) extends RocksDBDataEncoder(keyStateEncoderSpec, valueSchema) with Logging { + // Avro schema used by the avro encoders private lazy val keyAvroType: Schema = SchemaConverters.toAvroType(keySchema) private lazy val keyProj = UnsafeProjection.create(keySchema) - private lazy val valueProj = UnsafeProjection.create(valueSchema) private lazy val valueAvroType: Schema = SchemaConverters.toAvroType(valueSchema) + private lazy val valueProj = UnsafeProjection.create(valueSchema) // Prefix Key schema and projection definitions used by the Avro Serializers // and Deserializers @@ -346,10 +356,11 @@ class AvroStateEncoder( StructType(keySchema.take (numColsPrefixKey)) case _ => null } - private lazy val prefixKeyAvroType = SchemaConverters.toAvroType(prefixKeySchema) private lazy val prefixKeyProj = UnsafeProjection.create(prefixKeySchema) + // Range Key schema nd projection definitions used by the Avro Serializers and + // Deserializers private lazy val rangeScanKeyFieldsWithOrdinal = keyStateEncoderSpec match { case RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) => orderingOrdinals.map { ordinal => @@ -366,7 +377,7 @@ class AvroStateEncoder( private lazy val rangeScanAvroProjection = UnsafeProjection.create(rangeScanAvroSchema) - // Existing remainder key schema stuff + // Existing remainder key schema definitions // Remaining Key schema and projection definitions used by the Avro Serializers // and Deserializers private val remainingKeySchema = keyStateEncoderSpec match { @@ -427,6 +438,7 @@ class AvroStateEncoder( } private val out = new ByteArrayOutputStream + override def encodeKey(row: UnsafeRow): Array[Byte] = { keyStateEncoderSpec match { case NoPrefixKeyStateEncoderSpec(_) => @@ -447,8 +459,27 @@ class AvroStateEncoder( } } - override def encodePrefixKeyForRangeScan( - row: UnsafeRow): Array[Byte] = { + /** + * Encodes an UnsafeRow into 
an Avro-compatible byte array format for range scan operations. + * + * This method transforms row data into a binary format that preserves ordering when + * used in range scans. + * For each field in the row: + * - A marker byte is written to indicate null status or sign (for numeric types) + * - The value is written in big-endian format + * + * Special handling is implemented for: + * - Null values: marked with nullValMarker followed by zero bytes + * - Negative numbers: marked with negativeValMarker + * - Floating point numbers: bit manipulation to handle sign and NaN values correctly + * + * @param row The UnsafeRow to encode + * @param avroType The Avro schema defining the structure for encoding + * @return Array[Byte] containing the Avro-encoded data that preserves ordering for range scans + * @throws UnsupportedOperationException if a field's data type is not supported for range + * scan encoding + */ + override def encodePrefixKeyForRangeScan(row: UnsafeRow): Array[Byte] = { val record = new GenericData.Record(rangeScanAvroType) var fieldIdx = 0 rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) => @@ -594,6 +625,25 @@ class AvroStateEncoder( } } + /** + * Decodes an Avro-encoded byte array back into an UnsafeRow for range scan operations. + * + * This method reverses the encoding process performed by encodePrefixKeyForRangeScan: + * - Reads the marker byte to determine null status or sign + * - Reconstructs the original values from big-endian format + * - Handles special cases for floating point numbers by reversing bit manipulations + * + * The decoding process preserves the original data types and values, including: + * - Null values marked by nullValMarker + * - Sign information for numeric types + * - Proper restoration of negative floating point values + * + * @param bytes The Avro-encoded byte array to decode + * @param avroType The Avro schema defining the structure for decoding + * @return UnsafeRow containing the decoded data + * @throws UnsupportedOperationException if a field's data type is not supported for range + * scan decoding + */ override def decodePrefixKeyForRangeScan(bytes: Array[Byte]): UnsafeRow = { val reader = new GenericDatumReader[GenericRecord](rangeScanAvroType) val decoder = DecoderFactory.get().binaryDecoder(bytes, 0, bytes.length, null) @@ -730,21 +780,21 @@ abstract class RocksDBKeyStateEncoderBase( object RocksDBStateEncoder { def getKeyEncoder( - stateEncoder: RocksDBDataEncoder, + dataEncoder: RocksDBDataEncoder, keyStateEncoderSpec: KeyStateEncoderSpec, useColumnFamilies: Boolean, virtualColFamilyId: Option[Short] = None): RocksDBKeyStateEncoder = { // Return the key state encoder based on the requested type keyStateEncoderSpec match { case NoPrefixKeyStateEncoderSpec(keySchema) => - new NoPrefixKeyStateEncoder(stateEncoder, keySchema, useColumnFamilies, virtualColFamilyId) + new NoPrefixKeyStateEncoder(dataEncoder, keySchema, useColumnFamilies, virtualColFamilyId) case PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) => - new PrefixKeyScanStateEncoder(stateEncoder, keySchema, numColsPrefixKey, + new PrefixKeyScanStateEncoder(dataEncoder, keySchema, numColsPrefixKey, useColumnFamilies, virtualColFamilyId) case RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) => - new RangeKeyScanStateEncoder(stateEncoder, keySchema, orderingOrdinals, + new RangeKeyScanStateEncoder(dataEncoder, keySchema, orderingOrdinals, useColumnFamilies, virtualColFamilyId) case _ => @@ -754,13 +804,13 @@ object 
RocksDBStateEncoder { } def getValueEncoder( - stateEncoder: RocksDBDataEncoder, + dataEncoder: RocksDBDataEncoder, valueSchema: StructType, useMultipleValuesPerKey: Boolean): RocksDBValueStateEncoder = { if (useMultipleValuesPerKey) { - new MultiValuedStateEncoder(stateEncoder, valueSchema) + new MultiValuedStateEncoder(dataEncoder, valueSchema) } else { - new SingleValueStateEncoder(stateEncoder, valueSchema) + new SingleValueStateEncoder(dataEncoder, valueSchema) } } @@ -774,12 +824,13 @@ object RocksDBStateEncoder { /** * RocksDB Key Encoder for UnsafeRow that supports prefix scan * + * @param dataEncoder - the encoder that handles actual encoding/decoding of data * @param keySchema - schema of the key to be encoded * @param numColsPrefixKey - number of columns to be used for prefix key * @param useColumnFamilies - if column family is enabled for this encoder */ class PrefixKeyScanStateEncoder( - stateEncoder: RocksDBDataEncoder, + dataEncoder: RocksDBDataEncoder, keySchema: StructType, numColsPrefixKey: Int, useColumnFamilies: Boolean = false, @@ -812,8 +863,8 @@ class PrefixKeyScanStateEncoder( private val joinedRowOnKey = new JoinedRow() override def encodeKey(row: UnsafeRow): Array[Byte] = { - val prefixKeyEncoded = stateEncoder.encodeKey(extractPrefixKey(row)) - val remainingEncoded = stateEncoder.encodeRemainingKey(remainingKeyProjection(row)) + val prefixKeyEncoded = dataEncoder.encodeKey(extractPrefixKey(row)) + val remainingEncoded = dataEncoder.encodeRemainingKey(remainingKeyProjection(row)) val (encodedBytes, startingOffset) = encodeColumnFamilyPrefix( prefixKeyEncoded.length + remainingEncoded.length + 4 @@ -845,9 +896,9 @@ class PrefixKeyScanStateEncoder( Platform.copyMemory(keyBytes, decodeKeyStartOffset + 4 + prefixKeyEncodedLen, remainingKeyEncoded, Platform.BYTE_ARRAY_OFFSET, remainingKeyEncodedLen) - val prefixKeyDecoded = stateEncoder.decodeKey( + val prefixKeyDecoded = dataEncoder.decodeKey( prefixKeyEncoded) - val remainingKeyDecoded = stateEncoder.decodeRemainingKey(remainingKeyEncoded) + val remainingKeyDecoded = dataEncoder.decodeRemainingKey(remainingKeyEncoded) restoreKeyProjection(joinedRowOnKey.withLeft(prefixKeyDecoded).withRight(remainingKeyDecoded)) } @@ -857,7 +908,7 @@ class PrefixKeyScanStateEncoder( } override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = { - val prefixKeyEncoded = stateEncoder.encodeKey(prefixKey) + val prefixKeyEncoded = dataEncoder.encodeKey(prefixKey) val (prefix, startingOffset) = encodeColumnFamilyPrefix( prefixKeyEncoded.length + 4 ) @@ -898,12 +949,13 @@ class PrefixKeyScanStateEncoder( * the right lexicographical ordering. 
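Conceptually, the PrefixKeyScanStateEncoder encode/decode pair above writes a 4-byte prefix length, then the encoded prefix, then the encoded remaining key, so the two halves can be split apart again on read. A simplified sketch of that layout (not part of the patch; it ignores the column family prefix and uses ByteBuffer where the real code uses Platform.putInt):

import java.nio.ByteBuffer

object PrefixKeyLayoutSketch {
  def pack(prefix: Array[Byte], remaining: Array[Byte]): Array[Byte] =
    ByteBuffer.allocate(4 + prefix.length + remaining.length)
      .putInt(prefix.length)
      .put(prefix)
      .put(remaining)
      .array()

  def unpack(bytes: Array[Byte]): (Array[Byte], Array[Byte]) = {
    val buf = ByteBuffer.wrap(bytes)
    val prefix = new Array[Byte](buf.getInt())
    buf.get(prefix)
    val remaining = new Array[Byte](buf.remaining())
    buf.get(remaining)
    (prefix, remaining)
  }

  def main(args: Array[String]): Unit = {
    val (p, r) = unpack(pack(Array[Byte](1, 2), Array[Byte](3, 4, 5)))
    assert(p.sameElements(Array[Byte](1, 2)) && r.sameElements(Array[Byte](3, 4, 5)))
  }
}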
For the rationale around this, please check the link * here: https://en.wikipedia.org/wiki/IEEE_754#Design_rationale * + * @param dataEncoder - the encoder that handles the actual encoding/decoding of data * @param keySchema - schema of the key to be encoded * @param orderingOrdinals - the ordinals for which the range scan is constructed * @param useColumnFamilies - if column family is enabled for this encoder */ class RangeKeyScanStateEncoder( - stateEncoder: RocksDBDataEncoder, + dataEncoder: RocksDBDataEncoder, keySchema: StructType, orderingOrdinals: Seq[Int], useColumnFamilies: Boolean = false, @@ -985,10 +1037,10 @@ class RangeKeyScanStateEncoder( override def encodeKey(row: UnsafeRow): Array[Byte] = { // This prefix key has the columns specified by orderingOrdinals val prefixKey = extractPrefixKey(row) - val rangeScanKeyEncoded = stateEncoder.encodePrefixKeyForRangeScan(prefixKey) + val rangeScanKeyEncoded = dataEncoder.encodePrefixKeyForRangeScan(prefixKey) val result = if (orderingOrdinals.length < keySchema.length) { - val remainingEncoded = stateEncoder.encodeRemainingKey(remainingKeyProjection(row)) + val remainingEncoded = dataEncoder.encodeRemainingKey(remainingKeyProjection(row)) val (encodedBytes, startingOffset) = encodeColumnFamilyPrefix( rangeScanKeyEncoded.length + remainingEncoded.length + 4 ) @@ -1025,7 +1077,7 @@ class RangeKeyScanStateEncoder( Platform.copyMemory(keyBytes, decodeKeyStartOffset + 4, prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET, prefixKeyEncodedLen) - val prefixKeyDecoded = stateEncoder.decodePrefixKeyForRangeScan( + val prefixKeyDecoded = dataEncoder.decodePrefixKeyForRangeScan( prefixKeyEncoded) if (orderingOrdinals.length < keySchema.length) { @@ -1038,7 +1090,7 @@ class RangeKeyScanStateEncoder( remainingKeyEncoded, Platform.BYTE_ARRAY_OFFSET, remainingKeyEncodedLen) - val remainingKeyDecoded = stateEncoder.decodeRemainingKey(remainingKeyEncoded) + val remainingKeyDecoded = dataEncoder.decodeRemainingKey(remainingKeyEncoded) val joined = joinedRowOnKey.withLeft(prefixKeyDecoded).withRight(remainingKeyDecoded) val restored = restoreKeyProjection(joined) @@ -1051,7 +1103,7 @@ class RangeKeyScanStateEncoder( } override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = { - val rangeScanKeyEncoded = stateEncoder.encodePrefixKeyForRangeScan(prefixKey) + val rangeScanKeyEncoded = dataEncoder.encodePrefixKeyForRangeScan(prefixKey) val (prefix, startingOffset) = encodeColumnFamilyPrefix(rangeScanKeyEncoded.length + 4) Platform.putInt(prefix, startingOffset, rangeScanKeyEncoded.length) @@ -1076,7 +1128,7 @@ class RangeKeyScanStateEncoder( * then the generated array byte will be N+1 bytes. 
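A small standalone sketch of why the range-scan encoding above (a sign-marker byte followed by a big-endian value) works: once keys are laid out this way, the unsigned lexicographic order that RocksDB applies to raw key bytes matches numeric order. This is illustrative only and handles just Long values; the marker bytes stand in for the encoder's negativeValMarker and positiveValMarker.

import java.nio.ByteBuffer

object RangeScanOrderingSketch {
  def encodeLong(v: Long): Array[Byte] = {
    val marker: Byte = if (v < 0) 0x00.toByte else 0x01.toByte
    ByteBuffer.allocate(1 + java.lang.Long.BYTES).put(marker).putLong(v).array()
  }

  // Unsigned lexicographic comparison, the order RocksDB uses for raw key bytes.
  def unsignedCompare(a: Array[Byte], b: Array[Byte]): Int =
    a.zip(b).map { case (x, y) => (x & 0xff) - (y & 0xff) }
      .find(_ != 0)
      .getOrElse(a.length - b.length)

  def main(args: Array[String]): Unit = {
    implicit val byteOrdering: Ordering[Array[Byte]] =
      Ordering.fromLessThan((x, y) => unsignedCompare(x, y) < 0)
    val sorted = Seq(42L, -7L, 0L, -100L, 5L).sortBy(encodeLong)
    assert(sorted == Seq(-100L, -7L, 0L, 5L, 42L))
  }
}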
*/ class NoPrefixKeyStateEncoder( - stateEncoder: RocksDBDataEncoder, + dataEncoder: RocksDBDataEncoder, keySchema: StructType, useColumnFamilies: Boolean = false, virtualColFamilyId: Option[Short] = None) @@ -1087,9 +1139,9 @@ class NoPrefixKeyStateEncoder( override def encodeKey(row: UnsafeRow): Array[Byte] = { if (!useColumnFamilies) { - stateEncoder.encodeKey(row) + dataEncoder.encodeKey(row) } else { - val bytesToEncode = stateEncoder.encodeKey(row) + val bytesToEncode = dataEncoder.encodeKey(row) val (encodedBytes, startingOffset) = encodeColumnFamilyPrefix( bytesToEncode.length + STATE_ENCODING_NUM_VERSION_BYTES @@ -1122,12 +1174,12 @@ class NoPrefixKeyStateEncoder( dataBytes, Platform.BYTE_ARRAY_OFFSET, dataLength) - stateEncoder.decodeKey(dataBytes) + dataEncoder.decodeKey(dataBytes) } else { null } } else { - stateEncoder.decodeKey(keyBytes) + dataEncoder.decodeKey(keyBytes) } } @@ -1152,12 +1204,12 @@ class NoPrefixKeyStateEncoder( * operation. */ class MultiValuedStateEncoder( - stateEncoder: RocksDBDataEncoder, + dataEncoder: RocksDBDataEncoder, valueSchema: StructType) extends RocksDBValueStateEncoder with Logging { override def encodeValue(row: UnsafeRow): Array[Byte] = { - val bytes = stateEncoder.encodeValue(row) + val bytes = dataEncoder.encodeValue(row) val numBytes = bytes.length val encodedBytes = new Array[Byte](java.lang.Integer.BYTES + bytes.length) @@ -1176,7 +1228,7 @@ class MultiValuedStateEncoder( val encodedValue = new Array[Byte](numBytes) Platform.copyMemory(valueBytes, java.lang.Integer.BYTES + Platform.BYTE_ARRAY_OFFSET, encodedValue, Platform.BYTE_ARRAY_OFFSET, numBytes) - stateEncoder.decodeValue(encodedValue) + dataEncoder.decodeValue(encodedValue) } } @@ -1202,7 +1254,7 @@ class MultiValuedStateEncoder( pos += numBytes pos += 1 // eat the delimiter character - stateEncoder.decodeValue(encodedValue) + dataEncoder.decodeValue(encodedValue) } } } @@ -1224,11 +1276,11 @@ class MultiValuedStateEncoder( * then the generated array byte will be N+1 bytes. */ class SingleValueStateEncoder( - stateEncoder: RocksDBDataEncoder, + dataEncoder: RocksDBDataEncoder, valueSchema: StructType) extends RocksDBValueStateEncoder { - override def encodeValue(row: UnsafeRow): Array[Byte] = stateEncoder.encodeValue(row) + override def encodeValue(row: UnsafeRow): Array[Byte] = dataEncoder.encodeValue(row) /** * Decode byte array for a value to a UnsafeRow. @@ -1237,7 +1289,7 @@ class SingleValueStateEncoder( * the given byte array. 
*/ override def decodeValue(valueBytes: Array[Byte]): UnsafeRow = { - stateEncoder.decodeValue(valueBytes) + dataEncoder.decodeValue(valueBytes) } override def supportsMultipleValuesPerKey: Boolean = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 791727d5210d2..8183a9a1cf7a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -611,6 +611,7 @@ object RocksDBStateStoreProvider { val VIRTUAL_COL_FAMILY_PREFIX_BYTES = 2 private val MAX_AVRO_ENCODERS_IN_CACHE = 1000 + // Add the cache at companion object level so it persists across provider instances private val dataEncoderCache: NonFateSharingCache[String, RocksDBDataEncoder] = { val guavaCache = CacheBuilder.newBuilder() @@ -621,6 +622,23 @@ object RocksDBStateStoreProvider { new NonFateSharingCache(guavaCache) } + /** + * Creates and returns a data encoder for the state store based on the specified encoding type. + * This method handles caching of encoders to improve performance by reusing encoder instances + * when possible. + * + * The method supports two encoding types: + * - Avro: Uses Apache Avro for serialization with schema evolution support + * - UnsafeRow: Uses Spark's internal row format for optimal performance + * + * @param stateStoreEncoding The encoding type to use ("avro" or "unsaferow") + * @param encoderCacheKey A unique key for caching the encoder instance, typically combining + * query ID, operator ID, partition ID, and column family name + * @param keyStateEncoderSpec Specification for how to encode keys, including schema and any + * prefix/range scan requirements + * @param valueSchema The schema for the values to be encoded + * @return A RocksDBDataEncoder instance configured for the specified encoding type + */ def getDataEncoder( stateStoreEncoding: String, encoderCacheKey: String, @@ -628,7 +646,7 @@ object RocksDBStateStoreProvider { valueSchema: StructType): RocksDBDataEncoder = { stateStoreEncoding match { - case "avro" => + case StateStoreEncoding.Avro.toString => RocksDBStateStoreProvider.dataEncoderCache.get( encoderCacheKey, new java.util.concurrent.Callable[AvroStateEncoder] { @@ -638,7 +656,7 @@ object RocksDBStateStoreProvider { } } ) - case "unsaferow" => + case StateStoreEncoding.UnsafeRow.toString => RocksDBStateStoreProvider.dataEncoderCache.get( encoderCacheKey, new java.util.concurrent.Callable[UnsafeRowDataEncoder] { @@ -673,6 +691,21 @@ object RocksDBStateStoreProvider { avroOptions.stableIdPrefixForUnionType, avroOptions.recursiveFieldMaxDepth) } + /** + * Creates an AvroEncoder that handles both key and value serialization/deserialization. + * This method sets up the complete encoding infrastructure needed for state store operations. + * + * The encoder handles different key encoding specifications: + * - NoPrefixKeyStateEncoderSpec: Simple key encoding without prefix + * - PrefixKeyScanStateEncoderSpec: Keys with prefix for efficient scanning + * - RangeKeyScanStateEncoderSpec: Keys with ordering requirements for range scans + * + * For prefix scan cases, it also creates separate encoders for the suffix portion of keys. 
+ * + * @param keyStateEncoderSpec Specification for how to encode keys + * @param valueSchema Schema for the values to be encoded + * @return An AvroEncoder containing all necessary serializers and deserializers + */ private def createAvroEnc( keyStateEncoderSpec: KeyStateEncoderSpec, valueSchema: StructType From 5d5cb6520c49bd7118dc9a645c5902db289deaed Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Sat, 23 Nov 2024 15:20:43 -0800 Subject: [PATCH 05/17] more stuff --- .../streaming/state/RocksDBStateEncoder.scala | 28 +++++++++++++------ 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index b7eb95644c405..fb9f828fb777e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -94,6 +94,15 @@ abstract class RocksDBDataEncoder( val positiveValMarker: Byte = 0x01.toByte val nullValMarker: Byte = 0x02.toByte + + def unsupportedOperationForKeyStateEncoder( + operation: String + ): UnsupportedOperationException = { + new UnsupportedOperationException( + s"Method $operation not supported for encoder spec type " + + s"${keyStateEncoderSpec.getClass.getSimpleName}") + } + /** * Encode the UnsafeRow of N bytes as a N+1 byte array. * @note This creates a new byte array and memcopies the UnsafeRow to the new array. @@ -259,7 +268,7 @@ class UnsafeRowDataEncoder( decodeToUnsafeRow(bytes, reusedKeyRow) case PrefixKeyScanStateEncoderSpec(_, numColsPrefixKey) => decodeToUnsafeRow(bytes, numFields = numColsPrefixKey) - case _ => null + case _ => throw unsupportedOperationForKeyStateEncoder("decodeKey") } } @@ -269,7 +278,7 @@ class UnsafeRowDataEncoder( decodeToUnsafeRow(bytes, numFields = numColsPrefixKey) case RangeKeyScanStateEncoderSpec(_, orderingOrdinals) => decodeToUnsafeRow(bytes, keySchema.length - orderingOrdinals.length) - case _ => null + case _ => throw unsupportedOperationForKeyStateEncoder("decodeRemainingKey") } } @@ -354,7 +363,7 @@ class AvroStateEncoder( private lazy val prefixKeySchema = keyStateEncoderSpec match { case PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) => StructType(keySchema.take (numColsPrefixKey)) - case _ => null + case _ => throw unsupportedOperationForKeyStateEncoder("prefixKeySchema") } private lazy val prefixKeyAvroType = SchemaConverters.toAvroType(prefixKeySchema) private lazy val prefixKeyProj = UnsafeProjection.create(prefixKeySchema) @@ -367,7 +376,8 @@ class AvroStateEncoder( val field = keySchema(ordinal) (field, ordinal) } - case _ => null + case _ => + throw unsupportedOperationForKeyStateEncoder("rangeScanKey") } private lazy val rangeScanAvroSchema = StateStoreColumnFamilySchemaUtils.convertForRangeScan( @@ -385,7 +395,7 @@ class AvroStateEncoder( StructType(keySchema.drop(numColsPrefixKey)) case RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) => StructType(0.until(keySchema.length).diff(orderingOrdinals).map(keySchema(_))) - case _ => null + case _ => throw unsupportedOperationForKeyStateEncoder("remainingKeySchema") } private lazy val remainingKeyAvroType = SchemaConverters.toAvroType(remainingKeySchema) @@ -445,7 +455,7 @@ class AvroStateEncoder( encodeUnsafeRowToAvro(row, avroEncoder.keySerializer, keyAvroType, out) case PrefixKeyScanStateEncoderSpec(_, _) => 
encodeUnsafeRowToAvro(row, avroEncoder.keySerializer, prefixKeyAvroType, out) - case _ => null + case _ => throw unsupportedOperationForKeyStateEncoder("encodeKey") } } @@ -455,7 +465,7 @@ class AvroStateEncoder( encodeUnsafeRowToAvro(row, avroEncoder.suffixKeySerializer.get, remainingKeyAvroType, out) case RangeKeyScanStateEncoderSpec(_, _) => encodeUnsafeRowToAvro(row, avroEncoder.keySerializer, remainingKeyAvroType, out) - case _ => null + case _ => throw unsupportedOperationForKeyStateEncoder("encodeRemainingKey") } } @@ -608,7 +618,7 @@ class AvroStateEncoder( case PrefixKeyScanStateEncoderSpec(_, _) => decodeFromAvroToUnsafeRow( bytes, avroEncoder.keyDeserializer, prefixKeyAvroType, prefixKeyProj) - case _ => null + case _ => throw unsupportedOperationForKeyStateEncoder("decodeKey") } } @@ -621,7 +631,7 @@ class AvroStateEncoder( case RangeKeyScanStateEncoderSpec(_, _) => decodeFromAvroToUnsafeRow( bytes, avroEncoder.keyDeserializer, remainingKeyAvroType, remainingKeyAvroProjection) - case _ => null + case _ => throw unsupportedOperationForKeyStateEncoder("decodeRemainingKey") } } From 8d860eeaf218753d31ad14f4682033cbfa82701f Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Mon, 25 Nov 2024 13:06:17 -0800 Subject: [PATCH 06/17] scalastyle --- .../execution/streaming/state/RocksDBStateStoreProvider.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 8183a9a1cf7a8..0e82f34d0e9a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -646,7 +646,7 @@ object RocksDBStateStoreProvider { valueSchema: StructType): RocksDBDataEncoder = { stateStoreEncoding match { - case StateStoreEncoding.Avro.toString => + case "avro" => RocksDBStateStoreProvider.dataEncoderCache.get( encoderCacheKey, new java.util.concurrent.Callable[AvroStateEncoder] { @@ -656,7 +656,7 @@ object RocksDBStateStoreProvider { } } ) - case StateStoreEncoding.UnsafeRow.toString => + case "unsaferow" => RocksDBStateStoreProvider.dataEncoderCache.get( encoderCacheKey, new java.util.concurrent.Callable[UnsafeRowDataEncoder] { From d4f8b31fa893fb8e90aca5c20394e26608308092 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Mon, 25 Nov 2024 16:42:10 -0800 Subject: [PATCH 07/17] lazy val --- .../sql/execution/streaming/state/RocksDBStateEncoder.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index fb9f828fb777e..34b0d9a31d61f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -390,7 +390,7 @@ class AvroStateEncoder( // Existing remainder key schema definitions // Remaining Key schema and projection definitions used by the Avro Serializers // and Deserializers - private val remainingKeySchema = keyStateEncoderSpec match { + private lazy val remainingKeySchema = keyStateEncoderSpec match { case PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) 
=> StructType(keySchema.drop(numColsPrefixKey)) case RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) => From ca1353c117ef201d94b6b908bfa6463dee4f9c37 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Fri, 13 Dec 2024 10:27:02 -0800 Subject: [PATCH 08/17] compiles --- .../sql/execution/streaming/state/RocksDBStateEncoder.scala | 1 - .../execution/streaming/state/RocksDBStateStoreProvider.scala | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index 33b41b25ee048..30e4539ef7139 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -852,7 +852,6 @@ class PrefixKeyScanStateEncoder( virtualColFamilyId: Option[Short] = None) extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) with Logging { - private val usingAvroEncoding = avroEnc.isDefined private val prefixKeyFieldsWithIdx: Seq[(StructField, Int)] = { keySchema.zipWithIndex.take(numColsPrefixKey) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 669f49fad7e1e..e945032c1166a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -23,7 +23,6 @@ import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import scala.util.control.NonFatal -import com.google.common.cache.CacheBuilder import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -614,7 +613,7 @@ object RocksDBStateStoreProvider { private val AVRO_ENCODER_LIFETIME_HOURS = 1L // Add the cache at companion object level so it persists across provider instances - private val avroEncoderMap: NonFateSharingCache[String, AvroEncoder] = + private val dataEncoderCache: NonFateSharingCache[String, RocksDBDataEncoder] = NonFateSharingCache( maximumSize = MAX_AVRO_ENCODERS_IN_CACHE, expireAfterAccessTime = AVRO_ENCODER_LIFETIME_HOURS, From 85aa8da871835aea851c9b47d89b81d3a37e19a9 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Fri, 13 Dec 2024 13:26:42 -0800 Subject: [PATCH 09/17] feedback --- .../streaming/state/RocksDBStateEncoder.scala | 162 ++++++++++++------ .../state/RocksDBStateStoreProvider.scala | 66 +++---- .../streaming/state/StateStore.scala | 36 ++++ .../streaming/TransformWithStateTTLTest.scala | 4 +- 4 files changed, 187 insertions(+), 81 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index 30e4539ef7139..c3d541a29ab3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -60,18 +60,94 @@ sealed trait RocksDBValueStateEncoder { * by the callers. The metadata in each row does not need to be written as Avro or UnsafeRow, * but the actual data provided by the caller does. 
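To make the split described in the comment above concrete with a deliberately tiny, non-Spark sketch (all names hypothetical): the caller's payload can come from interchangeable encoders, while the storage layer wraps whichever payload it receives with its own fixed header metadata.

import java.nio.charset.StandardCharsets

trait ToyDataEncoder {
  def encode(value: String): Array[Byte]
  def decode(bytes: Array[Byte]): String
}

// "Direct layout" flavour: the payload is the raw UTF-8 bytes of the value.
object Utf8Encoder extends ToyDataEncoder {
  def encode(value: String): Array[Byte] = value.getBytes(StandardCharsets.UTF_8)
  def decode(bytes: Array[Byte]): String = new String(bytes, StandardCharsets.UTF_8)
}

// Alternative wire-format flavour, standing in for the Avro path.
object ReversedEncoder extends ToyDataEncoder {
  def encode(value: String): Array[Byte] = value.reverse.getBytes(StandardCharsets.UTF_8)
  def decode(bytes: Array[Byte]): String = new String(bytes, StandardCharsets.UTF_8).reverse
}

object ToyStore {
  // The store always adds a one-byte version header, independent of the payload format.
  def put(encoder: ToyDataEncoder, value: String): Array[Byte] = Array[Byte](0) ++ encoder.encode(value)
  def get(encoder: ToyDataEncoder, stored: Array[Byte]): String = encoder.decode(stored.drop(1))

  def main(args: Array[String]): Unit =
    Seq(Utf8Encoder, ReversedEncoder).foreach { enc =>
      assert(get(enc, put(enc, "state")) == "state")
    }
}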
*/ +/** Interface for encoding and decoding state store data between UnsafeRow and raw bytes. + * + * @note All encode methods expect non-null input rows. Handling of null values is left to the + * implementing classes. + */ trait DataEncoder { + /** Encodes a complete key row into bytes. Used as the primary key for state lookups. + * + * @param row An UnsafeRow containing all key columns as defined in the keySchema + * @return Serialized byte array representation of the key + */ def encodeKey(row: UnsafeRow): Array[Byte] + + /** Encodes the non-prefix portion of a key row. Used with prefix scan and + * range scan state lookups where the key is split into prefix and remaining portions. + * + * For prefix scans: Encodes columns after the prefix columns + * For range scans: Encodes columns not included in the ordering columns + * + * @param row An UnsafeRow containing only the remaining key columns + * @return Serialized byte array of the remaining key portion + * @throws UnsupportedOperationException if called on an encoder that doesn't support split keys + */ def encodeRemainingKey(row: UnsafeRow): Array[Byte] + + /** Encodes key columns used for range scanning, ensuring proper sort order in RocksDB. + * + * This method handles special encoding for numeric types to maintain correct sort order: + * - Adds sign byte markers for numeric types + * - Flips bits for negative floating point values + * - Preserves null ordering + * + * @param row An UnsafeRow containing the columns needed for range scan + * (specified by orderingOrdinals) + * @return Serialized bytes that will maintain correct sort order in RocksDB + * @throws UnsupportedOperationException if called on an encoder that doesn't support range scans + */ def encodePrefixKeyForRangeScan(row: UnsafeRow): Array[Byte] + + /** Encodes a value row into bytes. + * + * @param row An UnsafeRow containing the value columns as defined in the valueSchema + * @return Serialized byte array representation of the value + */ def encodeValue(row: UnsafeRow): Array[Byte] + /** Decodes a complete key from its serialized byte form. + * + * For NoPrefixKeyStateEncoder: Decodes the entire key + * For PrefixKeyScanStateEncoder: Decodes only the prefix portion + * + * @param bytes Serialized byte array containing the encoded key + * @return UnsafeRow containing the decoded key columns + * @throws UnsupportedOperationException for unsupported encoder types + */ def decodeKey(bytes: Array[Byte]): UnsafeRow + + /** Decodes the remaining portion of a split key from its serialized form. + * + * For PrefixKeyScanStateEncoder: Decodes columns after the prefix + * For RangeKeyScanStateEncoder: Decodes non-ordering columns + * + * @param bytes Serialized byte array containing the encoded remaining key portion + * @return UnsafeRow containing the decoded remaining key columns + * @throws UnsupportedOperationException if called on an encoder that doesn't support split keys + */ def decodeRemainingKey(bytes: Array[Byte]): UnsafeRow + + /** Decodes range scan key bytes back into an UnsafeRow, preserving proper ordering. 
+ * + * This method reverses the special encoding done by encodePrefixKeyForRangeScan: + * - Interprets sign byte markers + * - Reverses bit flipping for negative floating point values + * - Handles null values + * + * @param bytes Serialized byte array containing the encoded range scan key + * @return UnsafeRow containing the decoded range scan columns + * @throws UnsupportedOperationException if called on an encoder that doesn't support range scans + */ def decodePrefixKeyForRangeScan(bytes: Array[Byte]): UnsafeRow + + /** Decodes a value from its serialized byte form. + * + * @param bytes Serialized byte array containing the encoded value + * @return UnsafeRow containing the decoded value columns + */ def decodeValue(bytes: Array[Byte]): UnsafeRow } - abstract class RocksDBDataEncoder( keyStateEncoderSpec: KeyStateEncoderSpec, valueSchema: StructType) extends DataEncoder { @@ -789,37 +865,43 @@ abstract class RocksDBKeyStateEncoderBase( } } +/** Factory object for creating state encoders used by RocksDB state store. + * + * The encoders created by this object handle serialization and deserialization of state data, + * supporting both key and value encoding with various access patterns + * (e.g., prefix scan, range scan). + */ object RocksDBStateEncoder extends Logging { + + /** Creates a key encoder based on the specified encoding strategy and configuration. + * + * @param dataEncoder The underlying encoder that handles the actual data encoding/decoding + * @param keyStateEncoderSpec Specification defining the key encoding strategy + * (no prefix, prefix scan, or range scan) + * @param useColumnFamilies Whether to use RocksDB column families for storage + * @param virtualColFamilyId Optional column family identifier when column families are enabled + * @return A configured RocksDBKeyStateEncoder instance + */ def getKeyEncoder( dataEncoder: RocksDBDataEncoder, keyStateEncoderSpec: KeyStateEncoderSpec, useColumnFamilies: Boolean, - virtualColFamilyId: Option[Short] = None, - avroEnc: Option[AvroEncoder] = None): RocksDBKeyStateEncoder = { - // Return the key state encoder based on the requested type - keyStateEncoderSpec match { - case NoPrefixKeyStateEncoderSpec(keySchema) => - new NoPrefixKeyStateEncoder(dataEncoder, keySchema, useColumnFamilies, virtualColFamilyId) - - case PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) => - new PrefixKeyScanStateEncoder(dataEncoder, keySchema, numColsPrefixKey, - useColumnFamilies, virtualColFamilyId) - - case RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) => - new RangeKeyScanStateEncoder(dataEncoder, keySchema, orderingOrdinals, - useColumnFamilies, virtualColFamilyId) - - case _ => - throw new IllegalArgumentException(s"Unsupported key state encoder spec: " + - s"$keyStateEncoderSpec") - } + virtualColFamilyId: Option[Short] = None): RocksDBKeyStateEncoder = { + keyStateEncoderSpec.toEncoder(dataEncoder, useColumnFamilies, virtualColFamilyId) } + /** Creates a value encoder that supports either single or multiple values per key. 
+ * + * @param dataEncoder The underlying encoder that handles the actual data encoding/decoding + * @param valueSchema Schema defining the structure of values to be encoded + * @param useMultipleValuesPerKey If true, creates an encoder that can handle multiple values + * per key; if false, creates an encoder for single values + * @return A configured RocksDBValueStateEncoder instance + */ def getValueEncoder( dataEncoder: RocksDBDataEncoder, valueSchema: StructType, - useMultipleValuesPerKey: Boolean, - avroEnc: Option[AvroEncoder] = None): RocksDBValueStateEncoder = { + useMultipleValuesPerKey: Boolean): RocksDBValueStateEncoder = { if (useMultipleValuesPerKey) { new MultiValuedStateEncoder(dataEncoder, valueSchema) } else { @@ -827,6 +909,14 @@ object RocksDBStateEncoder extends Logging { } } + /** Encodes a virtual column family ID into a byte array suitable for RocksDB. + * + * This method creates a fixed-size byte array prefixed with the virtual column family ID, + * which is used to partition data within RocksDB. + * + * @param virtualColFamilyId The column family identifier to encode + * @return A byte array containing the encoded column family ID + */ def getColumnFamilyIdBytes(virtualColFamilyId: Short): Array[Byte] = { val encodedBytes = new Array[Byte](VIRTUAL_COL_FAMILY_PREFIX_BYTES) Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET, virtualColFamilyId) @@ -871,18 +961,6 @@ class PrefixKeyScanStateEncoder( UnsafeProjection.create(refs) } - // Prefix Key schema and projection definitions used by the Avro Serializers - // and Deserializers - private val prefixKeySchema = StructType(keySchema.take(numColsPrefixKey)) - private lazy val prefixKeyAvroType = SchemaConverters.toAvroType(prefixKeySchema) - private val prefixKeyProj = UnsafeProjection.create(prefixKeySchema) - - // Remaining Key schema and projection definitions used by the Avro Serializers - // and Deserializers - private val remainingKeySchema = StructType(keySchema.drop(numColsPrefixKey)) - private lazy val remainingKeyAvroType = SchemaConverters.toAvroType(remainingKeySchema) - private val remainingKeyProj = UnsafeProjection.create(remainingKeySchema) - // This is quite simple to do - just bind sequentially, as we don't change the order. 
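As a rough standalone illustration of the two-byte column family prefix produced by getColumnFamilyIdBytes above (not part of the patch; ByteBuffer writes big-endian where Platform.putShort follows native order, so this mirrors the idea rather than the exact bytes):

import java.nio.ByteBuffer

object ColFamilyPrefixSketch {
  val prefixBytes = 2 // stand-in for VIRTUAL_COL_FAMILY_PREFIX_BYTES

  def prefixKey(colFamilyId: Short, keyBytes: Array[Byte]): Array[Byte] =
    ByteBuffer.allocate(prefixBytes + keyBytes.length).putShort(colFamilyId).put(keyBytes).array()

  def colFamilyId(prefixedKey: Array[Byte]): Short = ByteBuffer.wrap(prefixedKey).getShort()

  def main(args: Array[String]): Unit = {
    val prefixed = prefixKey(7, Array[Byte](1, 2, 3))
    assert(colFamilyId(prefixed) == 7 && prefixed.length == prefixBytes + 3)
  }
}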
private val restoreKeyProjection: UnsafeProjection = UnsafeProjection.create(keySchema) @@ -1056,22 +1134,6 @@ class RangeKeyScanStateEncoder( UnsafeProjection.create(refs) } - private val rangeScanAvroSchema = StateStoreColumnFamilySchemaUtils.convertForRangeScan( - StructType(rangeScanKeyFieldsWithOrdinal.map(_._1).toArray)) - - private lazy val rangeScanAvroType = SchemaConverters.toAvroType(rangeScanAvroSchema) - - private val rangeScanAvroProjection = UnsafeProjection.create(rangeScanAvroSchema) - - // Existing remainder key schema stuff - private val remainingKeySchema = StructType( - 0.to(keySchema.length - 1).diff(orderingOrdinals).map(keySchema(_)) - ) - - private lazy val remainingKeyAvroType = SchemaConverters.toAvroType(remainingKeySchema) - - private val remainingKeyAvroProjection = UnsafeProjection.create(remainingKeySchema) - // Reusable objects private val joinedRowOnKey = new JoinedRow() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index e945032c1166a..1bc74dd4057c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -82,10 +82,18 @@ private[sql] class RocksDBStateStoreProvider val dataEncoder = getDataEncoder( stateStoreEncoding, dataEncoderCacheKey, keyStateEncoderSpec, valueSchema) - keyValueEncoderMap.putIfAbsent(colFamilyName, - (RocksDBStateEncoder.getKeyEncoder(dataEncoder, keyStateEncoderSpec, useColumnFamilies, - Some(newColFamilyId)), RocksDBStateEncoder.getValueEncoder(dataEncoder, valueSchema, - useMultipleValuesPerKey))) + val keyEncoder = RocksDBStateEncoder.getKeyEncoder( + dataEncoder, + keyStateEncoderSpec, + useColumnFamilies, + Some(newColFamilyId) + ) + val valueEncoder = RocksDBStateEncoder.getValueEncoder( + dataEncoder, + valueSchema, + useMultipleValuesPerKey + ) + keyValueEncoderMap.putIfAbsent(colFamilyName, (keyEncoder, valueEncoder)) } override def get(key: UnsafeRow, colFamilyName: String): UnsafeRow = { @@ -392,10 +400,18 @@ private[sql] class RocksDBStateStoreProvider val dataEncoder = getDataEncoder( stateStoreEncoding, dataEncoderCacheKey, keyStateEncoderSpec, valueSchema) - keyValueEncoderMap.putIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME, - (RocksDBStateEncoder.getKeyEncoder(dataEncoder, keyStateEncoderSpec, - useColumnFamilies, defaultColFamilyId), - RocksDBStateEncoder.getValueEncoder(dataEncoder, valueSchema, useMultipleValuesPerKey))) + val keyEncoder = RocksDBStateEncoder.getKeyEncoder( + dataEncoder, + keyStateEncoderSpec, + useColumnFamilies, + defaultColFamilyId + ) + val valueEncoder = RocksDBStateEncoder.getValueEncoder( + dataEncoder, + valueSchema, + useMultipleValuesPerKey + ) + keyValueEncoderMap.putIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME, (keyEncoder, valueEncoder)) } override def stateStoreId: StateStoreId = stateStoreId_ @@ -642,28 +658,20 @@ object RocksDBStateStoreProvider { encoderCacheKey: String, keyStateEncoderSpec: KeyStateEncoderSpec, valueSchema: StructType): RocksDBDataEncoder = { - - stateStoreEncoding match { - case "avro" => - RocksDBStateStoreProvider.dataEncoderCache.get( - encoderCacheKey, - new java.util.concurrent.Callable[AvroStateEncoder] { - override def call(): AvroStateEncoder = { - val avroEncoder = createAvroEnc(keyStateEncoderSpec, valueSchema) - 
new AvroStateEncoder(keyStateEncoderSpec, valueSchema, avroEncoder) - } - } - ) - case "unsaferow" => - RocksDBStateStoreProvider.dataEncoderCache.get( - encoderCacheKey, - new java.util.concurrent.Callable[UnsafeRowDataEncoder] { - override def call(): UnsafeRowDataEncoder = { - new UnsafeRowDataEncoder(keyStateEncoderSpec, valueSchema) - } + assert(Set("avro", "unsaferow").contains(stateStoreEncoding)) + RocksDBStateStoreProvider.dataEncoderCache.get( + encoderCacheKey, + new java.util.concurrent.Callable[RocksDBDataEncoder] { + override def call(): RocksDBDataEncoder = { + if (stateStoreEncoding == "avro") { + val avroEncoder = createAvroEnc(keyStateEncoderSpec, valueSchema) + new AvroStateEncoder(keyStateEncoderSpec, valueSchema, avroEncoder) + } else { + new UnsafeRowDataEncoder(keyStateEncoderSpec, valueSchema) } - ) - } + } + } + ) } private def getRunId(hadoopConf: Configuration): String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index fa7946547bf68..ee1c558cd4bfc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -325,6 +325,18 @@ sealed trait KeyStateEncoderSpec { def keySchema: StructType def jsonValue: JValue def json: String = compact(render(jsonValue)) + + /** Creates a RocksDBKeyStateEncoder for this specification. + * + * @param dataEncoder The encoder to handle the actual data encoding/decoding + * @param useColumnFamilies Whether to use RocksDB column families + * @param virtualColFamilyId Optional column family ID when column families are used + * @return A RocksDBKeyStateEncoder configured for this spec + */ + def toEncoder( + dataEncoder: RocksDBDataEncoder, + useColumnFamilies: Boolean, + virtualColFamilyId: Option[Short]): RocksDBKeyStateEncoder } object KeyStateEncoderSpec { @@ -348,6 +360,14 @@ case class NoPrefixKeyStateEncoderSpec(keySchema: StructType) extends KeyStateEn override def jsonValue: JValue = { ("keyStateEncoderType" -> JString("NoPrefixKeyStateEncoderSpec")) } + + override def toEncoder( + dataEncoder: RocksDBDataEncoder, + useColumnFamilies: Boolean, + virtualColFamilyId: Option[Short]): RocksDBKeyStateEncoder = { + new NoPrefixKeyStateEncoder( + dataEncoder, keySchema, useColumnFamilies, virtualColFamilyId) + } } case class PrefixKeyScanStateEncoderSpec( @@ -356,6 +376,14 @@ case class PrefixKeyScanStateEncoderSpec( if (numColsPrefixKey == 0 || numColsPrefixKey >= keySchema.length) { throw StateStoreErrors.incorrectNumOrderingColsForPrefixScan(numColsPrefixKey.toString) } + override def toEncoder( + dataEncoder: RocksDBDataEncoder, + useColumnFamilies: Boolean, + virtualColFamilyId: Option[Short]): RocksDBKeyStateEncoder = { + new PrefixKeyScanStateEncoder( + dataEncoder, keySchema, numColsPrefixKey, useColumnFamilies, virtualColFamilyId) + } + override def jsonValue: JValue = { ("keyStateEncoderType" -> JString("PrefixKeyScanStateEncoderSpec")) ~ @@ -371,6 +399,14 @@ case class RangeKeyScanStateEncoderSpec( throw StateStoreErrors.incorrectNumOrderingColsForRangeScan(orderingOrdinals.length.toString) } + override def toEncoder( + dataEncoder: RocksDBDataEncoder, + useColumnFamilies: Boolean, + virtualColFamilyId: Option[Short]): RocksDBKeyStateEncoder = { + new RangeKeyScanStateEncoder( + dataEncoder, keySchema, orderingOrdinals, useColumnFamilies, 
virtualColFamilyId) + } + override def jsonValue: JValue = { ("keyStateEncoderType" -> JString("RangeKeyScanStateEncoderSpec")) ~ ("orderingOrdinals" -> orderingOrdinals.map(JInt(_))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala index 93e3112b4f3a3..55a46f51f9f6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala @@ -21,7 +21,7 @@ import java.sql.Timestamp import java.time.Duration import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithEncodingTypes, RocksDBStateStoreProvider} +import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, AlsoTestWithEncodingTypes, RocksDBStateStoreProvider} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock @@ -41,7 +41,7 @@ case class OutputEvent( * Test suite base for TransformWithState with TTL support. */ abstract class TransformWithStateTTLTest - extends StreamTest + extends StreamTest with AlsoTestWithChangelogCheckpointingEnabled with AlsoTestWithEncodingTypes { import testImplicits._ From ff1e0f489d90d00e56f4fb69836e97532c305e8c Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Fri, 13 Dec 2024 13:42:28 -0800 Subject: [PATCH 10/17] moving avro enc creation into RocksDBStateENcoder --- .../streaming/state/RocksDBStateEncoder.scala | 80 ++++++++++++++++++- .../state/RocksDBStateStoreProvider.scala | 76 +----------------- 2 files changed, 78 insertions(+), 78 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index c3d541a29ab3e..1d8a6d004ef97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -27,7 +27,7 @@ import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWri import org.apache.avro.io.{DecoderFactory, EncoderFactory} import org.apache.spark.internal.Logging -import org.apache.spark.sql.avro.{AvroDeserializer, AvroSerializer, SchemaConverters} +import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions, AvroSerializer, SchemaConverters} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter @@ -423,10 +423,10 @@ class UnsafeRowDataEncoder( class AvroStateEncoder( keyStateEncoderSpec: KeyStateEncoderSpec, - valueSchema: StructType, - avroEncoder: AvroEncoder) extends RocksDBDataEncoder(keyStateEncoderSpec, valueSchema) + valueSchema: StructType) extends RocksDBDataEncoder(keyStateEncoderSpec, valueSchema) with Logging { + private val avroEncoder = createAvroEnc(keyStateEncoderSpec, valueSchema) // Avro schema used by the avro encoders private lazy val keyAvroType: Schema = SchemaConverters.toAvroType(keySchema) private lazy val keyProj = UnsafeProjection.create(keySchema) @@ -478,6 +478,80 @@ class AvroStateEncoder( private lazy val remainingKeyAvroProjection = 
UnsafeProjection.create(remainingKeySchema) + + + private def getAvroSerializer(schema: StructType): AvroSerializer = { + val avroType = SchemaConverters.toAvroType(schema) + new AvroSerializer(schema, avroType, nullable = false) + } + + private def getAvroDeserializer(schema: StructType): AvroDeserializer = { + val avroType = SchemaConverters.toAvroType(schema) + val avroOptions = AvroOptions(Map.empty) + new AvroDeserializer(avroType, schema, + avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType, + avroOptions.stableIdPrefixForUnionType, avroOptions.recursiveFieldMaxDepth) + } + + /** + * Creates an AvroEncoder that handles both key and value serialization/deserialization. + * This method sets up the complete encoding infrastructure needed for state store operations. + * + * The encoder handles different key encoding specifications: + * - NoPrefixKeyStateEncoderSpec: Simple key encoding without prefix + * - PrefixKeyScanStateEncoderSpec: Keys with prefix for efficient scanning + * - RangeKeyScanStateEncoderSpec: Keys with ordering requirements for range scans + * + * For prefix scan cases, it also creates separate encoders for the suffix portion of keys. + * + * @param keyStateEncoderSpec Specification for how to encode keys + * @param valueSchema Schema for the values to be encoded + * @return An AvroEncoder containing all necessary serializers and deserializers + */ + private def createAvroEnc( + keyStateEncoderSpec: KeyStateEncoderSpec, + valueSchema: StructType + ): AvroEncoder = { + val valueSerializer = getAvroSerializer(valueSchema) + val valueDeserializer = getAvroDeserializer(valueSchema) + + // Get key schema based on encoder spec type + val keySchema = keyStateEncoderSpec match { + case NoPrefixKeyStateEncoderSpec(schema) => + schema + case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) => + StructType(schema.take(numColsPrefixKey)) + case RangeKeyScanStateEncoderSpec(schema, orderingOrdinals) => + val remainingSchema = { + 0.until(schema.length).diff(orderingOrdinals).map { ordinal => + schema(ordinal) + } + } + StructType(remainingSchema) + } + + // Handle suffix key schema for prefix scan case + val suffixKeySchema = keyStateEncoderSpec match { + case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) => + Some(StructType(schema.drop(numColsPrefixKey))) + case _ => + None + } + + val keySerializer = getAvroSerializer(keySchema) + val keyDeserializer = getAvroDeserializer(keySchema) + + // Create the AvroEncoder with all components + AvroEncoder( + keySerializer, + keyDeserializer, + valueSerializer, + valueDeserializer, + suffixKeySchema.map(getAvroSerializer), + suffixKeySchema.map(getAvroDeserializer) + ) + } + /** * This method takes an UnsafeRow, and serializes to a byte array using Avro encoding. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 1bc74dd4057c4..44751ddc45e60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -30,7 +30,6 @@ import org.apache.spark.{SparkConf, SparkEnv, SparkException} import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys._ import org.apache.spark.io.CompressionCodec -import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions, AvroSerializer, SchemaConverters} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, StreamExecution} @@ -664,8 +663,7 @@ object RocksDBStateStoreProvider { new java.util.concurrent.Callable[RocksDBDataEncoder] { override def call(): RocksDBDataEncoder = { if (stateStoreEncoding == "avro") { - val avroEncoder = createAvroEnc(keyStateEncoderSpec, valueSchema) - new AvroStateEncoder(keyStateEncoderSpec, valueSchema, avroEncoder) + new AvroStateEncoder(keyStateEncoderSpec, valueSchema) } else { new UnsafeRowDataEncoder(keyStateEncoderSpec, valueSchema) } @@ -684,78 +682,6 @@ object RocksDBStateStoreProvider { } } - private def getAvroSerializer(schema: StructType): AvroSerializer = { - val avroType = SchemaConverters.toAvroType(schema) - new AvroSerializer(schema, avroType, nullable = false) - } - - private def getAvroDeserializer(schema: StructType): AvroDeserializer = { - val avroType = SchemaConverters.toAvroType(schema) - val avroOptions = AvroOptions(Map.empty) - new AvroDeserializer(avroType, schema, - avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType, - avroOptions.stableIdPrefixForUnionType, avroOptions.recursiveFieldMaxDepth) - } - - /** - * Creates an AvroEncoder that handles both key and value serialization/deserialization. - * This method sets up the complete encoding infrastructure needed for state store operations. - * - * The encoder handles different key encoding specifications: - * - NoPrefixKeyStateEncoderSpec: Simple key encoding without prefix - * - PrefixKeyScanStateEncoderSpec: Keys with prefix for efficient scanning - * - RangeKeyScanStateEncoderSpec: Keys with ordering requirements for range scans - * - * For prefix scan cases, it also creates separate encoders for the suffix portion of keys. 
- * - * @param keyStateEncoderSpec Specification for how to encode keys - * @param valueSchema Schema for the values to be encoded - * @return An AvroEncoder containing all necessary serializers and deserializers - */ - private def createAvroEnc( - keyStateEncoderSpec: KeyStateEncoderSpec, - valueSchema: StructType - ): AvroEncoder = { - val valueSerializer = getAvroSerializer(valueSchema) - val valueDeserializer = getAvroDeserializer(valueSchema) - - // Get key schema based on encoder spec type - val keySchema = keyStateEncoderSpec match { - case NoPrefixKeyStateEncoderSpec(schema) => - schema - case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) => - StructType(schema.take(numColsPrefixKey)) - case RangeKeyScanStateEncoderSpec(schema, orderingOrdinals) => - val remainingSchema = { - 0.until(schema.length).diff(orderingOrdinals).map { ordinal => - schema(ordinal) - } - } - StructType(remainingSchema) - } - - // Handle suffix key schema for prefix scan case - val suffixKeySchema = keyStateEncoderSpec match { - case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) => - Some(StructType(schema.drop(numColsPrefixKey))) - case _ => - None - } - - val keySerializer = getAvroSerializer(keySchema) - val keyDeserializer = getAvroDeserializer(keySchema) - - // Create the AvroEncoder with all components - AvroEncoder( - keySerializer, - keyDeserializer, - valueSerializer, - valueDeserializer, - suffixKeySchema.map(getAvroSerializer), - suffixKeySchema.map(getAvroDeserializer) - ) - } - // Native operation latencies report as latency in microseconds // as SQLMetrics support millis. Convert the value to millis val CUSTOM_METRIC_GET_TIME = StateStoreCustomTimingMetric( From a29d14de909ac6c4b166b8e7fa55be6ff0134778 Mon Sep 17 00:00:00 2001 From: Eric Marnadi <132308037+ericm-db@users.noreply.github.com> Date: Fri, 13 Dec 2024 13:54:24 -0800 Subject: [PATCH 11/17] Update TransformWithMapStateSuite.scala --- .../apache/spark/sql/streaming/TransformWithMapStateSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala index 574d2245dea93..6884ef577f8ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala @@ -82,7 +82,6 @@ class TestMapStateProcessor */ class TransformWithMapStateSuite extends StreamTest with AlsoTestWithRocksDBFeatures with AlsoTestWithEncodingTypes { - import testImplicits._ private def testMapStateWithNullUserKey(inputMapRow: InputMapRow): Unit = { From e3dc8a5e7182de0c53b9ff85ed31821e4d00a5de Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Sat, 14 Dec 2024 00:01:39 -0800 Subject: [PATCH 12/17] fixed stream-stream join --- .../streaming/state/RocksDBStateEncoder.scala | 1 + .../state/RocksDBStateStoreProvider.scala | 29 +++++++++++++++---- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index 1d8a6d004ef97..22cd10bef45a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -148,6 +148,7 
@@ trait DataEncoder { */ def decodeValue(bytes: Array[Byte]): UnsafeRow } + abstract class RocksDBDataEncoder( keyStateEncoderSpec: KeyStateEncoderSpec, valueSchema: StructType) extends DataEncoder { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 44751ddc45e60..fb0bf84d7aabc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -75,8 +75,12 @@ private[sql] class RocksDBStateStoreProvider isInternal: Boolean = false): Unit = { verifyColFamilyCreationOrDeletion("create_col_family", colFamilyName, isInternal) val newColFamilyId = rocksDB.createColFamilyIfAbsent(colFamilyName) - val dataEncoderCacheKey = s"${getRunId(hadoopConf)}_${stateStoreId.operatorId}_" + - s"${stateStoreId.partitionId}_${colFamilyName}" + val dataEncoderCacheKey = StateRowEncoderCacheKey( + queryRunId = getRunId(hadoopConf), + operatorId = stateStoreId.operatorId, + partitionId = stateStoreId.partitionId, + stateStoreName = stateStoreId.storeName, + colFamilyName = colFamilyName) val dataEncoder = getDataEncoder( stateStoreEncoding, dataEncoderCacheKey, keyStateEncoderSpec, valueSchema) @@ -393,8 +397,12 @@ private[sql] class RocksDBStateStoreProvider defaultColFamilyId = Some(rocksDB.createColFamilyIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME)) } - val dataEncoderCacheKey = s"${getRunId(hadoopConf)}_${stateStoreId.operatorId}_" + - s"${stateStoreId.partitionId}_${StateStore.DEFAULT_COL_FAMILY_NAME}" + val dataEncoderCacheKey = StateRowEncoderCacheKey( + queryRunId = getRunId(hadoopConf), + operatorId = stateStoreId.operatorId, + partitionId = stateStoreId.partitionId, + stateStoreName = stateStoreId.storeName, + colFamilyName = StateStore.DEFAULT_COL_FAMILY_NAME) val dataEncoder = getDataEncoder( stateStoreEncoding, dataEncoderCacheKey, keyStateEncoderSpec, valueSchema) @@ -618,6 +626,15 @@ private[sql] class RocksDBStateStoreProvider } } + +case class StateRowEncoderCacheKey( + queryRunId: String, + operatorId: Long, + partitionId: Int, + stateStoreName: String, + colFamilyName: String +) + object RocksDBStateStoreProvider { // Version as a single byte that specifies the encoding of the row data in RocksDB val STATE_ENCODING_NUM_VERSION_BYTES = 1 @@ -628,7 +645,7 @@ object RocksDBStateStoreProvider { private val AVRO_ENCODER_LIFETIME_HOURS = 1L // Add the cache at companion object level so it persists across provider instances - private val dataEncoderCache: NonFateSharingCache[String, RocksDBDataEncoder] = + private val dataEncoderCache: NonFateSharingCache[StateRowEncoderCacheKey, RocksDBDataEncoder] = NonFateSharingCache( maximumSize = MAX_AVRO_ENCODERS_IN_CACHE, expireAfterAccessTime = AVRO_ENCODER_LIFETIME_HOURS, @@ -654,7 +671,7 @@ object RocksDBStateStoreProvider { */ def getDataEncoder( stateStoreEncoding: String, - encoderCacheKey: String, + encoderCacheKey: StateRowEncoderCacheKey, keyStateEncoderSpec: KeyStateEncoderSpec, valueSchema: StructType): RocksDBDataEncoder = { assert(Set("avro", "unsaferow").contains(stateStoreEncoding)) From 83539087ed3b89f8074ea8f8e6699985858a4854 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Mon, 16 Dec 2024 20:10:07 -0800 Subject: [PATCH 13/17] feedback, scaladoc --- .../streaming/state/RocksDBStateEncoder.scala | 49 
++++++++++--------- .../streaming/state/StateStore.scala | 1 + 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index 22cd10bef45a0..d76a6ffa72eed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -60,20 +60,17 @@ sealed trait RocksDBValueStateEncoder { * by the callers. The metadata in each row does not need to be written as Avro or UnsafeRow, * but the actual data provided by the caller does. */ -/** Interface for encoding and decoding state store data between UnsafeRow and raw bytes. - * - * @note All encode methods expect non-null input rows. Handling of null values is left to the - * implementing classes. - */ trait DataEncoder { - /** Encodes a complete key row into bytes. Used as the primary key for state lookups. + /** + * Encodes a complete key row into bytes. Used as the primary key for state lookups. * * @param row An UnsafeRow containing all key columns as defined in the keySchema * @return Serialized byte array representation of the key */ def encodeKey(row: UnsafeRow): Array[Byte] - /** Encodes the non-prefix portion of a key row. Used with prefix scan and + /** + * Encodes the non-prefix portion of a key row. Used with prefix scan and * range scan state lookups where the key is split into prefix and remaining portions. * * For prefix scans: Encodes columns after the prefix columns @@ -85,7 +82,8 @@ trait DataEncoder { */ def encodeRemainingKey(row: UnsafeRow): Array[Byte] - /** Encodes key columns used for range scanning, ensuring proper sort order in RocksDB. + /** + * Encodes key columns used for range scanning, ensuring proper sort order in RocksDB. * * This method handles special encoding for numeric types to maintain correct sort order: * - Adds sign byte markers for numeric types @@ -99,14 +97,16 @@ trait DataEncoder { */ def encodePrefixKeyForRangeScan(row: UnsafeRow): Array[Byte] - /** Encodes a value row into bytes. + /** + * Encodes a value row into bytes. * * @param row An UnsafeRow containing the value columns as defined in the valueSchema * @return Serialized byte array representation of the value */ def encodeValue(row: UnsafeRow): Array[Byte] - /** Decodes a complete key from its serialized byte form. + /** + * Decodes a complete key from its serialized byte form. * * For NoPrefixKeyStateEncoder: Decodes the entire key * For PrefixKeyScanStateEncoder: Decodes only the prefix portion @@ -117,7 +117,8 @@ trait DataEncoder { */ def decodeKey(bytes: Array[Byte]): UnsafeRow - /** Decodes the remaining portion of a split key from its serialized form. + /** + * Decodes the remaining portion of a split key from its serialized form. * * For PrefixKeyScanStateEncoder: Decodes columns after the prefix * For RangeKeyScanStateEncoder: Decodes non-ordering columns @@ -128,7 +129,8 @@ trait DataEncoder { */ def decodeRemainingKey(bytes: Array[Byte]): UnsafeRow - /** Decodes range scan key bytes back into an UnsafeRow, preserving proper ordering. + /** + * Decodes range scan key bytes back into an UnsafeRow, preserving proper ordering. 
* * This method reverses the special encoding done by encodePrefixKeyForRangeScan: * - Interprets sign byte markers @@ -141,7 +143,8 @@ trait DataEncoder { */ def decodePrefixKeyForRangeScan(bytes: Array[Byte]): UnsafeRow - /** Decodes a value from its serialized byte form. + /** + * Decodes a value from its serialized byte form. * * @param bytes Serialized byte array containing the encoded value * @return UnsafeRow containing the decoded value columns @@ -479,8 +482,6 @@ class AvroStateEncoder( private lazy val remainingKeyAvroProjection = UnsafeProjection.create(remainingKeySchema) - - private def getAvroSerializer(schema: StructType): AvroSerializer = { val avroType = SchemaConverters.toAvroType(schema) new AvroSerializer(schema, avroType, nullable = false) @@ -510,9 +511,8 @@ class AvroStateEncoder( * @return An AvroEncoder containing all necessary serializers and deserializers */ private def createAvroEnc( - keyStateEncoderSpec: KeyStateEncoderSpec, - valueSchema: StructType - ): AvroEncoder = { + keyStateEncoderSpec: KeyStateEncoderSpec, + valueSchema: StructType): AvroEncoder = { val valueSerializer = getAvroSerializer(valueSchema) val valueDeserializer = getAvroDeserializer(valueSchema) @@ -900,6 +900,7 @@ abstract class RocksDBKeyStateEncoderBase( if (useColumnFamilies) VIRTUAL_COL_FAMILY_PREFIX_BYTES else 0 val out = new ByteArrayOutputStream + /** * Get Byte Array for the virtual column family id that is used as prefix for * key state rows. @@ -940,7 +941,8 @@ abstract class RocksDBKeyStateEncoderBase( } } -/** Factory object for creating state encoders used by RocksDB state store. +/** + * Factory object for creating state encoders used by RocksDB state store. * * The encoders created by this object handle serialization and deserialization of state data, * supporting both key and value encoding with various access patterns @@ -948,7 +950,8 @@ abstract class RocksDBKeyStateEncoderBase( */ object RocksDBStateEncoder extends Logging { - /** Creates a key encoder based on the specified encoding strategy and configuration. + /** + * Creates a key encoder based on the specified encoding strategy and configuration. * * @param dataEncoder The underlying encoder that handles the actual data encoding/decoding * @param keyStateEncoderSpec Specification defining the key encoding strategy @@ -965,7 +968,8 @@ object RocksDBStateEncoder extends Logging { keyStateEncoderSpec.toEncoder(dataEncoder, useColumnFamilies, virtualColFamilyId) } - /** Creates a value encoder that supports either single or multiple values per key. + /** + * Creates a value encoder that supports either single or multiple values per key. * * @param dataEncoder The underlying encoder that handles the actual data encoding/decoding * @param valueSchema Schema defining the structure of values to be encoded @@ -984,7 +988,8 @@ object RocksDBStateEncoder extends Logging { } } - /** Encodes a virtual column family ID into a byte array suitable for RocksDB. + /** + * Encodes a virtual column family ID into a byte array suitable for RocksDB. * * This method creates a fixed-size byte array prefixed with the virtual column family ID, * which is used to partition data within RocksDB. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index ee1c558cd4bfc..373990d7118a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -376,6 +376,7 @@ case class PrefixKeyScanStateEncoderSpec( if (numColsPrefixKey == 0 || numColsPrefixKey >= keySchema.length) { throw StateStoreErrors.incorrectNumOrderingColsForPrefixScan(numColsPrefixKey.toString) } + override def toEncoder( dataEncoder: RocksDBDataEncoder, useColumnFamilies: Boolean, From ec9e300d30698b887ade6cfc6c5920d7ec35928f Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Mon, 16 Dec 2024 20:18:45 -0800 Subject: [PATCH 14/17] removing the avroEnc --- .../streaming/state/RocksDBStateEncoder.scala | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index d76a6ffa72eed..c415536a6f8d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -1309,7 +1309,6 @@ class RangeKeyScanStateEncoder( * It uses the first byte of the generated byte array to store the version the describes how the * row is encoded in the rest of the byte array. Currently, the default version is 0, * - * If the avroEnc is specified, we are using Avro encoding for this column family's keys * VERSION 0: [ VERSION (1 byte) | ROW (N bytes) ] * The bytes of a UnsafeRow is written unmodified to starting from offset 1 * (offset 0 is the version byte of value 0). That is, if the unsafe row has N bytes, @@ -1319,16 +1318,9 @@ class NoPrefixKeyStateEncoder( dataEncoder: RocksDBDataEncoder, keySchema: StructType, useColumnFamilies: Boolean = false, - virtualColFamilyId: Option[Short] = None, - avroEnc: Option[AvroEncoder] = None) + virtualColFamilyId: Option[Short] = None) extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) with Logging { - // Reusable objects - private val usingAvroEncoding = avroEnc.isDefined - private val keyRow = new UnsafeRow(keySchema.size) - private lazy val keyAvroType = SchemaConverters.toAvroType(keySchema) - private val keyProj = UnsafeProjection.create(keySchema) - override def encodeKey(row: UnsafeRow): Array[Byte] = { if (!useColumnFamilies) { dataEncoder.encodeKey(row) @@ -1394,7 +1386,6 @@ class NoPrefixKeyStateEncoder( * This encoder supports RocksDB StringAppendOperator merge operator. Values encoded can be * merged in RocksDB using merge operation, and all merged values can be read using decodeValues * operation. - * If the avroEnc is specified, we are using Avro encoding for this column family's values */ class MultiValuedStateEncoder( dataEncoder: RocksDBDataEncoder, @@ -1467,7 +1458,6 @@ class MultiValuedStateEncoder( * The bytes of a UnsafeRow is written unmodified to starting from offset 1 * (offset 0 is the version byte of value 0). That is, if the unsafe row has N bytes, * then the generated array byte will be N+1 bytes. 
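The VERSION 0 layout described in the comment above is simply a one-byte header prepended to the serialized row bytes. A minimal sketch of that framing (the helper name is made up for the example):

    // Prepend a single version byte; an N-byte row becomes an (N + 1)-byte value.
    def withVersionByte(rowBytes: Array[Byte], version: Byte = 0): Array[Byte] = {
      val out = new Array[Byte](rowBytes.length + 1)
      out(0) = version
      System.arraycopy(rowBytes, 0, out, 1, rowBytes.length)
      out
    }

Reading it back is the mirror image: check the version at offset 0, then wrap the remaining N bytes as the row.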
- * If the avroEnc is specified, we are using Avro encoding for this column family's values */ class SingleValueStateEncoder( dataEncoder: RocksDBDataEncoder, From d12d877c046cfc57a657dbd0eddbec582645272c Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Mon, 16 Dec 2024 20:30:01 -0800 Subject: [PATCH 15/17] removing avroenc from scaladoc --- .../sql/execution/streaming/state/RocksDBStateEncoder.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index c415536a6f8d0..9e8c895ff8bd0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -1011,8 +1011,6 @@ object RocksDBStateEncoder extends Logging { * @param keySchema - schema of the key to be encoded * @param numColsPrefixKey - number of columns to be used for prefix key * @param useColumnFamilies - if column family is enabled for this encoder - * @param avroEnc - if Avro encoding is specified for this StateEncoder, this encoder will - * be defined */ class PrefixKeyScanStateEncoder( dataEncoder: RocksDBDataEncoder, @@ -1138,8 +1136,6 @@ class PrefixKeyScanStateEncoder( * @param keySchema - schema of the key to be encoded * @param orderingOrdinals - the ordinals for which the range scan is constructed * @param useColumnFamilies - if column family is enabled for this encoder - * @param avroEnc - if Avro encoding is specified for this StateEncoder, this encoder will - * be defined */ class RangeKeyScanStateEncoder( dataEncoder: RocksDBDataEncoder, From 2dcf844d32e082a794361a5317ae7d5aad75dacd Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Mon, 16 Dec 2024 20:56:27 -0800 Subject: [PATCH 16/17] adding scaladoc --- .../sql/execution/streaming/state/RocksDBStateEncoder.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index 9e8c895ff8bd0..b4f6197811939 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -59,6 +59,8 @@ sealed trait RocksDBValueStateEncoder { * headers, footers and other metadata, but they also have data that is provided * by the callers. The metadata in each row does not need to be written as Avro or UnsafeRow, * but the actual data provided by the caller does. + * The classes that use this trait require specialized partial encoding which makes them much + * easier to cache and use, which is why each DataEncoder deals with multiple schemas. 
*/ trait DataEncoder { /** From 5028d662c9f3c88bf13a4402f6a6604a35a95c4e Mon Sep 17 00:00:00 2001 From: Eric Marnadi <132308037+ericm-db@users.noreply.github.com> Date: Tue, 17 Dec 2024 13:03:01 -0800 Subject: [PATCH 17/17] Update StateStore.scala --- .../spark/sql/execution/streaming/state/StateStore.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 373990d7118a7..de10518035e2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -326,7 +326,8 @@ sealed trait KeyStateEncoderSpec { def jsonValue: JValue def json: String = compact(render(jsonValue)) - /** Creates a RocksDBKeyStateEncoder for this specification. + /** + * Creates a RocksDBKeyStateEncoder for this specification. * * @param dataEncoder The encoder to handle the actual data encoding/decoding * @param useColumnFamilies Whether to use RocksDB column families
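One design choice in this series worth spelling out: the encoder cache key moved from a concatenated string to a dedicated case class (StateRowEncoderCacheKey) that also carries the state store name, presumably so that two stores under the same operator and partition, as in a stream-stream join, no longer collide on a cache entry. The sketch below shows the shape of that pattern with a plain ConcurrentHashMap standing in for Spark's NonFateSharingCache; the class and method names are illustrative, not the actual Spark API.

    import java.util.concurrent.ConcurrentHashMap

    // A case class key gives structural equality and hashing for free, and a new
    // component such as the store name cannot be lost to delimiter ambiguity the
    // way it can in a hand-built "runId_operatorId_partitionId_..." string.
    case class EncoderCacheKey(
        queryRunId: String,
        operatorId: Long,
        partitionId: Int,
        stateStoreName: String,
        colFamilyName: String)

    final class EncoderCache[V <: AnyRef] {
      private val cache = new ConcurrentHashMap[EncoderCacheKey, V]()

      // Build the value at most once per distinct key; later lookups reuse it.
      def getOrCreate(key: EncoderCacheKey)(build: => V): V =
        cache.computeIfAbsent(key, _ => build)
    }

With this key, the stores on the two sides of a join differ only in stateStoreName yet still map to distinct encoders, while repeated lookups for the same column family within a run keep hitting the cached instance.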