From 5bc6578c602374203c97c08d1f8aa6300e205d2c Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Sat, 16 Nov 2024 12:52:43 -0800 Subject: [PATCH 1/9] [WIP] Avro encoding only in StateStore code --- .../apache/spark/sql/internal/SQLConf.scala | 13 + .../StateStoreColumnFamilySchemaUtils.scala | 41 +- .../streaming/state/RocksDBStateEncoder.scala | 499 ++++++++++++++++-- .../state/RocksDBStateStoreProvider.scala | 68 ++- .../StateSchemaCompatibilityChecker.scala | 17 +- .../streaming/state/StateStore.scala | 12 + .../streaming/state/StateStoreConf.scala | 3 + .../streaming/state/ListStateSuite.scala | 14 +- .../streaming/state/MapStateSuite.scala | 12 +- .../state/RocksDBStateStoreSuite.scala | 59 ++- .../streaming/state/RocksDBSuite.scala | 67 +++ .../streaming/state/TimerSuite.scala | 2 +- .../streaming/state/ValueStateSuite.scala | 28 +- .../TransformWithListStateSuite.scala | 16 +- .../TransformWithListStateTTLSuite.scala | 6 +- .../TransformWithMapStateSuite.scala | 10 +- .../TransformWithMapStateTTLSuite.scala | 4 +- .../streaming/TransformWithStateSuite.scala | 58 +- .../streaming/TransformWithStateTTLTest.scala | 4 +- .../TransformWithValueStateTTLSuite.scala | 5 +- 20 files changed, 790 insertions(+), 148 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5218a683a8fa8..d1c4120b9b854 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2230,6 +2230,17 @@ object SQLConf { .intConf .createWithDefault(1) + val STREAMING_STATE_STORE_ENCODING_FORMAT = + buildConf("spark.sql.streaming.stateStore.encodingFormat") + .doc("The encoding format used for stateful operators to store information " + + "in the state store") + .version("4.0.0") + .stringConf + .transform(_.toLowerCase(Locale.ROOT)) + .checkValue(v => Set("unsaferow", "avro").contains(v), + "Valid values are 'unsaferow' and 'avro'") + .createWithDefault("unsaferow") + val STATE_STORE_COMPRESSION_CODEC = buildConf("spark.sql.streaming.stateStore.compression.codec") .internal() @@ -5598,6 +5609,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def stateStoreCheckpointFormatVersion: Int = getConf(STATE_STORE_CHECKPOINT_FORMAT_VERSION) + def stateStoreEncodingFormat: String = getConf(STREAMING_STATE_STORE_ENCODING_FORMAT) + def checkpointRenamedFileCheck: Boolean = getConf(CHECKPOINT_RENAMEDFILE_CHECK_ENABLED) def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala index 7da8408f98b0f..585298fa4c993 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala @@ -20,10 +20,49 @@ import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._ import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, StateStoreColFamilySchema} -import org.apache.spark.sql.types.StructType +import 
org.apache.spark.sql.types._ object StateStoreColumnFamilySchemaUtils { + /** + * Avro uses zig-zag encoding for some fixed-length types, like Longs and Ints. For range scans + * we want to use big-endian encoding, so we need to convert the source schema to replace these + * types with BinaryType. + * + * @param schema The schema to convert + * @param ordinals If non-empty, only convert fields at these ordinals. + * If empty, convert all fields. + */ + def convertForRangeScan(schema: StructType, ordinals: Seq[Int] = Seq.empty): StructType = { + val ordinalSet = ordinals.toSet + + StructType(schema.fields.zipWithIndex.flatMap { case (field, idx) => + if ((ordinals.isEmpty || ordinalSet.contains(idx)) && isFixedSize(field.dataType)) { + // For each numeric field, create two fields: + // 1. Byte marker for null, positive, or negative values + // 2. The original numeric value in big-endian format + // Byte type is converted to Int in Avro, which doesn't work for us as Avro + // uses zig-zag encoding as opposed to big-endian for Ints + Seq( + StructField(s"${field.name}_marker", BinaryType, nullable = false), + field.copy(name = s"${field.name}_value", BinaryType) + ) + } else { + Seq(field) + } + }) + } + + private def isFixedSize(dataType: DataType): Boolean = dataType match { + case _: ByteType | _: BooleanType | _: ShortType | _: IntegerType | _: LongType | + _: FloatType | _: DoubleType => true + case _ => false + } + + def getTtlColFamilyName(stateName: String): String = { + "$ttl_" + stateName + } + def getValueStateSchema[T]( stateName: String, keyEncoder: ExpressionEncoder[Any], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index 4c7a226e0973f..f097a4bb265e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -17,13 +17,21 @@ package org.apache.spark.sql.execution.streaming.state +import java.io.ByteArrayOutputStream import java.lang.Double.{doubleToRawLongBits, longBitsToDouble} import java.lang.Float.{floatToRawIntBits, intBitsToFloat} import java.nio.{ByteBuffer, ByteOrder} +import org.apache.avro.Schema +import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord} +import org.apache.avro.io.{DecoderFactory, EncoderFactory} + import org.apache.spark.internal.Logging +import org.apache.spark.sql.avro.{AvroDeserializer, AvroSerializer, SchemaConverters} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter +import org.apache.spark.sql.execution.streaming.StateStoreColumnFamilySchemaUtils import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{STATE_ENCODING_NUM_VERSION_BYTES, STATE_ENCODING_VERSION, VIRTUAL_COL_FAMILY_PREFIX_BYTES} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform @@ -49,6 +57,7 @@ abstract class RocksDBKeyStateEncoderBase( def offsetForColFamilyPrefix: Int = if (useColumnFamilies) VIRTUAL_COL_FAMILY_PREFIX_BYTES else 0 + val out = new ByteArrayOutputStream /** * Get Byte Array for the virtual column family id that is used as prefix for * key state rows. 
@@ -89,23 +98,24 @@ abstract class RocksDBKeyStateEncoderBase( } } -object RocksDBStateEncoder { +object RocksDBStateEncoder extends Logging { def getKeyEncoder( keyStateEncoderSpec: KeyStateEncoderSpec, useColumnFamilies: Boolean, - virtualColFamilyId: Option[Short] = None): RocksDBKeyStateEncoder = { + virtualColFamilyId: Option[Short] = None, + avroEnc: Option[AvroEncoder] = None): RocksDBKeyStateEncoder = { // Return the key state encoder based on the requested type keyStateEncoderSpec match { case NoPrefixKeyStateEncoderSpec(keySchema) => - new NoPrefixKeyStateEncoder(keySchema, useColumnFamilies, virtualColFamilyId) + new NoPrefixKeyStateEncoder(keySchema, useColumnFamilies, virtualColFamilyId, avroEnc) case PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) => new PrefixKeyScanStateEncoder(keySchema, numColsPrefixKey, - useColumnFamilies, virtualColFamilyId) + useColumnFamilies, virtualColFamilyId, avroEnc) case RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) => new RangeKeyScanStateEncoder(keySchema, orderingOrdinals, - useColumnFamilies, virtualColFamilyId) + useColumnFamilies, virtualColFamilyId, avroEnc) case _ => throw new IllegalArgumentException(s"Unsupported key state encoder spec: " + @@ -115,11 +125,12 @@ object RocksDBStateEncoder { def getValueEncoder( valueSchema: StructType, - useMultipleValuesPerKey: Boolean): RocksDBValueStateEncoder = { + useMultipleValuesPerKey: Boolean, + avroEnc: Option[AvroEncoder] = None): RocksDBValueStateEncoder = { if (useMultipleValuesPerKey) { - new MultiValuedStateEncoder(valueSchema) + new MultiValuedStateEncoder(valueSchema, avroEnc) } else { - new SingleValueStateEncoder(valueSchema) + new SingleValueStateEncoder(valueSchema, avroEnc) } } @@ -145,6 +156,26 @@ object RocksDBStateEncoder { encodedBytes } + /** + * This method takes an UnsafeRow, and serializes to a byte array using Avro encoding. 
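+   * The caller supplies a reusable `ByteArrayOutputStream` that is reset on every call; the
+   * key and value encoders below invoke it as, for example,
+   * {{{
+   *   encodeUnsafeRowToAvro(row, avroEnc.get.keySerializer, keyAvroType, out)
+   * }}}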
+ */ + def encodeUnsafeRowToAvro( + row: UnsafeRow, + avroSerializer: AvroSerializer, + valueAvroType: Schema, + out: ByteArrayOutputStream): Array[Byte] = { + // InternalRow -> Avro.GenericDataRecord + val avroData = + avroSerializer.serialize(row) + out.reset() + val encoder = EncoderFactory.get().directBinaryEncoder(out, null) + val writer = new GenericDatumWriter[Any]( + valueAvroType) // Defining Avro writer for this struct type + writer.write(avroData, encoder) // Avro.GenericDataRecord -> byte array + encoder.flush() + out.toByteArray + } + def decodeToUnsafeRow(bytes: Array[Byte], numFields: Int): UnsafeRow = { if (bytes != null) { val row = new UnsafeRow(numFields) @@ -154,6 +185,26 @@ object RocksDBStateEncoder { } } + /** + * This method takes a byte array written using Avro encoding, and + * deserializes to an UnsafeRow using the Avro deserializer + */ + def decodeFromAvroToUnsafeRow( + valueBytes: Array[Byte], + avroDeserializer: AvroDeserializer, + valueAvroType: Schema, + valueProj: UnsafeProjection): UnsafeRow = { + val reader = new GenericDatumReader[Any](valueAvroType) + val decoder = DecoderFactory.get().binaryDecoder(valueBytes, 0, valueBytes.length, null) + // bytes -> Avro.GenericDataRecord + val genericData = reader.read(null, decoder) + // Avro.GenericDataRecord -> InternalRow + val internalRow = avroDeserializer.deserialize( + genericData).orNull.asInstanceOf[InternalRow] + // InternalRow -> UnsafeRow + valueProj.apply(internalRow) + } + def decodeToUnsafeRow(bytes: Array[Byte], reusedRow: UnsafeRow): UnsafeRow = { if (bytes != null) { // Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st offset. See Platform. @@ -174,16 +225,20 @@ object RocksDBStateEncoder { * @param keySchema - schema of the key to be encoded * @param numColsPrefixKey - number of columns to be used for prefix key * @param useColumnFamilies - if column family is enabled for this encoder + * @param avroEnc - if Avro encoding is specified for this StateEncoder, this encoder will + * be defined */ class PrefixKeyScanStateEncoder( keySchema: StructType, numColsPrefixKey: Int, useColumnFamilies: Boolean = false, - virtualColFamilyId: Option[Short] = None) + virtualColFamilyId: Option[Short] = None, + avroEnc: Option[AvroEncoder] = None) extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) { import RocksDBStateEncoder._ + private val usingAvroEncoding = avroEnc.isDefined private val prefixKeyFieldsWithIdx: Seq[(StructField, Int)] = { keySchema.zipWithIndex.take(numColsPrefixKey) } @@ -203,6 +258,18 @@ class PrefixKeyScanStateEncoder( UnsafeProjection.create(refs) } + // Prefix Key schema and projection definitions used by the Avro Serializers + // and Deserializers + private val prefixKeySchema = StructType(keySchema.take(numColsPrefixKey)) + private lazy val prefixKeyAvroType = SchemaConverters.toAvroType(prefixKeySchema) + private val prefixKeyProj = UnsafeProjection.create(prefixKeySchema) + + // Remaining Key schema and projection definitions used by the Avro Serializers + // and Deserializers + private val remainingKeySchema = StructType(keySchema.drop(numColsPrefixKey)) + private lazy val remainingKeyAvroType = SchemaConverters.toAvroType(remainingKeySchema) + private val remainingKeyProj = UnsafeProjection.create(remainingKeySchema) + // This is quite simple to do - just bind sequentially, as we don't change the order. 
private val restoreKeyProjection: UnsafeProjection = UnsafeProjection.create(keySchema) @@ -210,9 +277,24 @@ class PrefixKeyScanStateEncoder( private val joinedRowOnKey = new JoinedRow() override def encodeKey(row: UnsafeRow): Array[Byte] = { - val prefixKeyEncoded = encodeUnsafeRow(extractPrefixKey(row)) - val remainingEncoded = encodeUnsafeRow(remainingKeyProjection(row)) - + val (prefixKeyEncoded, remainingEncoded) = if (usingAvroEncoding) { + ( + encodeUnsafeRowToAvro( + extractPrefixKey(row), + avroEnc.get.keySerializer, + prefixKeyAvroType, + out + ), + encodeUnsafeRowToAvro( + remainingKeyProjection(row), + avroEnc.get.suffixKeySerializer.get, + remainingKeyAvroType, + out + ) + ) + } else { + (encodeUnsafeRow(extractPrefixKey(row)), encodeUnsafeRow(remainingKeyProjection(row))) + } val (encodedBytes, startingOffset) = encodeColumnFamilyPrefix( prefixKeyEncoded.length + remainingEncoded.length + 4 ) @@ -243,9 +325,25 @@ class PrefixKeyScanStateEncoder( Platform.copyMemory(keyBytes, decodeKeyStartOffset + 4 + prefixKeyEncodedLen, remainingKeyEncoded, Platform.BYTE_ARRAY_OFFSET, remainingKeyEncodedLen) - val prefixKeyDecoded = decodeToUnsafeRow(prefixKeyEncoded, numFields = numColsPrefixKey) - val remainingKeyDecoded = decodeToUnsafeRow(remainingKeyEncoded, - numFields = keySchema.length - numColsPrefixKey) + val (prefixKeyDecoded, remainingKeyDecoded) = if (usingAvroEncoding) { + ( + decodeFromAvroToUnsafeRow( + prefixKeyEncoded, + avroEnc.get.keyDeserializer, + prefixKeyAvroType, + prefixKeyProj + ), + decodeFromAvroToUnsafeRow( + remainingKeyEncoded, + avroEnc.get.suffixKeyDeserializer.get, + remainingKeyAvroType, + remainingKeyProj + ) + ) + } else { + (decodeToUnsafeRow(prefixKeyEncoded, numFields = numColsPrefixKey), + decodeToUnsafeRow(remainingKeyEncoded, numFields = keySchema.length - numColsPrefixKey)) + } restoreKeyProjection(joinedRowOnKey.withLeft(prefixKeyDecoded).withRight(remainingKeyDecoded)) } @@ -255,7 +353,11 @@ class PrefixKeyScanStateEncoder( } override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = { - val prefixKeyEncoded = encodeUnsafeRow(prefixKey) + val prefixKeyEncoded = if (usingAvroEncoding) { + encodeUnsafeRowToAvro(prefixKey, avroEnc.get.keySerializer, prefixKeyAvroType, out) + } else { + encodeUnsafeRow(prefixKey) + } val (prefix, startingOffset) = encodeColumnFamilyPrefix( prefixKeyEncoded.length + 4 ) @@ -299,13 +401,16 @@ class PrefixKeyScanStateEncoder( * @param keySchema - schema of the key to be encoded * @param orderingOrdinals - the ordinals for which the range scan is constructed * @param useColumnFamilies - if column family is enabled for this encoder + * @param avroEnc - if Avro encoding is specified for this StateEncoder, this encoder will + * be defined */ class RangeKeyScanStateEncoder( keySchema: StructType, orderingOrdinals: Seq[Int], useColumnFamilies: Boolean = false, - virtualColFamilyId: Option[Short] = None) - extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) { + virtualColFamilyId: Option[Short] = None, + avroEnc: Option[AvroEncoder] = None) + extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) with Logging { import RocksDBStateEncoder._ @@ -374,6 +479,22 @@ class RangeKeyScanStateEncoder( UnsafeProjection.create(refs) } + private val rangeScanAvroSchema = StateStoreColumnFamilySchemaUtils.convertForRangeScan( + StructType(rangeScanKeyFieldsWithOrdinal.map(_._1).toArray)) + + private lazy val rangeScanAvroType = SchemaConverters.toAvroType(rangeScanAvroSchema) + + private val 
rangeScanAvroProjection = UnsafeProjection.create(rangeScanAvroSchema) + + // Existing remainder key schema stuff + private val remainingKeySchema = StructType( + 0.to(keySchema.length - 1).diff(orderingOrdinals).map(keySchema(_)) + ) + + private lazy val remainingKeyAvroType = SchemaConverters.toAvroType(remainingKeySchema) + + private val remainingKeyAvroProjection = UnsafeProjection.create(remainingKeySchema) + // Reusable objects private val joinedRowOnKey = new JoinedRow() @@ -563,13 +684,233 @@ class RangeKeyScanStateEncoder( writer.getRow() } + def encodePrefixKeyForRangeScan( + row: UnsafeRow, + avroType: Schema): Array[Byte] = { + val record = new GenericData.Record(avroType) + var fieldIdx = 0 + rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) => + val field = fieldWithOrdinal._1 + val value = row.get(idx, field.dataType) + + // Create marker byte buffer + val markerBuffer = ByteBuffer.allocate(1) + markerBuffer.order(ByteOrder.BIG_ENDIAN) + + if (value == null) { + markerBuffer.put(nullValMarker) + record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array())) + record.put(fieldIdx + 1, ByteBuffer.wrap(new Array[Byte](field.dataType.defaultSize))) + } else { + field.dataType match { + case BooleanType => + markerBuffer.put(positiveValMarker) + record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array())) + val valueBuffer = ByteBuffer.allocate(1) + valueBuffer.put(if (value.asInstanceOf[Boolean]) 1.toByte else 0.toByte) + record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array())) + + case ByteType => + val byteVal = value.asInstanceOf[Byte] + markerBuffer.put(if (byteVal < 0) negativeValMarker else positiveValMarker) + record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array())) + + val valueBuffer = ByteBuffer.allocate(1) + valueBuffer.order(ByteOrder.BIG_ENDIAN) + valueBuffer.put(byteVal) + record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array())) + + case ShortType => + val shortVal = value.asInstanceOf[Short] + markerBuffer.put(if (shortVal < 0) negativeValMarker else positiveValMarker) + record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array())) + + val valueBuffer = ByteBuffer.allocate(2) + valueBuffer.order(ByteOrder.BIG_ENDIAN) + valueBuffer.putShort(shortVal) + record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array())) + + case IntegerType => + val intVal = value.asInstanceOf[Int] + markerBuffer.put(if (intVal < 0) negativeValMarker else positiveValMarker) + record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array())) + + val valueBuffer = ByteBuffer.allocate(4) + valueBuffer.order(ByteOrder.BIG_ENDIAN) + valueBuffer.putInt(intVal) + record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array())) + + case LongType => + val longVal = value.asInstanceOf[Long] + markerBuffer.put(if (longVal < 0) negativeValMarker else positiveValMarker) + record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array())) + + val valueBuffer = ByteBuffer.allocate(8) + valueBuffer.order(ByteOrder.BIG_ENDIAN) + valueBuffer.putLong(longVal) + record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array())) + + case FloatType => + val floatVal = value.asInstanceOf[Float] + val rawBits = floatToRawIntBits(floatVal) + markerBuffer.put(if ((rawBits & floatSignBitMask) != 0) { + negativeValMarker + } else { + positiveValMarker + }) + record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array())) + + val valueBuffer = ByteBuffer.allocate(4) + valueBuffer.order(ByteOrder.BIG_ENDIAN) + if ((rawBits & floatSignBitMask) != 0) { + val updatedVal = rawBits ^ floatFlipBitMask + 
valueBuffer.putFloat(intBitsToFloat(updatedVal)) + } else { + valueBuffer.putFloat(floatVal) + } + record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array())) + + case DoubleType => + val doubleVal = value.asInstanceOf[Double] + val rawBits = doubleToRawLongBits(doubleVal) + markerBuffer.put(if ((rawBits & doubleSignBitMask) != 0) { + negativeValMarker + } else { + positiveValMarker + }) + record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array())) + + val valueBuffer = ByteBuffer.allocate(8) + valueBuffer.order(ByteOrder.BIG_ENDIAN) + if ((rawBits & doubleSignBitMask) != 0) { + val updatedVal = rawBits ^ doubleFlipBitMask + valueBuffer.putDouble(longBitsToDouble(updatedVal)) + } else { + valueBuffer.putDouble(doubleVal) + } + record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array())) + + case _ => throw new UnsupportedOperationException( + s"Range scan encoding not supported for data type: ${field.dataType}") + } + } + fieldIdx += 2 + } + + out.reset() + val writer = new GenericDatumWriter[GenericRecord](rangeScanAvroType) + val encoder = EncoderFactory.get().binaryEncoder(out, null) + writer.write(record, encoder) + encoder.flush() + out.toByteArray + } + + def decodePrefixKeyForRangeScan( + bytes: Array[Byte], + avroType: Schema): UnsafeRow = { + + val reader = new GenericDatumReader[GenericRecord](avroType) + val decoder = DecoderFactory.get().binaryDecoder(bytes, 0, bytes.length, null) + val record = reader.read(null, decoder) + + val rowWriter = new UnsafeRowWriter(rangeScanKeyFieldsWithOrdinal.length) + rowWriter.resetRowWriter() + + var fieldIdx = 0 + rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) => + val field = fieldWithOrdinal._1 + + val markerBytes = record.get(fieldIdx).asInstanceOf[ByteBuffer].array() + val markerBuf = ByteBuffer.wrap(markerBytes) + markerBuf.order(ByteOrder.BIG_ENDIAN) + val marker = markerBuf.get() + + if (marker == nullValMarker) { + rowWriter.setNullAt(idx) + } else { + field.dataType match { + case BooleanType => + val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array() + rowWriter.write(idx, bytes(0) == 1) + + case ByteType => + val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array() + val valueBuf = ByteBuffer.wrap(bytes) + valueBuf.order(ByteOrder.BIG_ENDIAN) + rowWriter.write(idx, valueBuf.get()) + + case ShortType => + val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array() + val valueBuf = ByteBuffer.wrap(bytes) + valueBuf.order(ByteOrder.BIG_ENDIAN) + rowWriter.write(idx, valueBuf.getShort()) + + case IntegerType => + val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array() + val valueBuf = ByteBuffer.wrap(bytes) + valueBuf.order(ByteOrder.BIG_ENDIAN) + rowWriter.write(idx, valueBuf.getInt()) + + case LongType => + val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array() + val valueBuf = ByteBuffer.wrap(bytes) + valueBuf.order(ByteOrder.BIG_ENDIAN) + rowWriter.write(idx, valueBuf.getLong()) + + case FloatType => + val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array() + val valueBuf = ByteBuffer.wrap(bytes) + valueBuf.order(ByteOrder.BIG_ENDIAN) + if (marker == negativeValMarker) { + val floatVal = valueBuf.getFloat + val updatedVal = floatToRawIntBits(floatVal) ^ floatFlipBitMask + rowWriter.write(idx, intBitsToFloat(updatedVal)) + } else { + rowWriter.write(idx, valueBuf.getFloat()) + } + + case DoubleType => + val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array() + val valueBuf = ByteBuffer.wrap(bytes) + 
valueBuf.order(ByteOrder.BIG_ENDIAN) + if (marker == negativeValMarker) { + val doubleVal = valueBuf.getDouble + val updatedVal = doubleToRawLongBits(doubleVal) ^ doubleFlipBitMask + rowWriter.write(idx, longBitsToDouble(updatedVal)) + } else { + rowWriter.write(idx, valueBuf.getDouble()) + } + + case _ => throw new UnsupportedOperationException( + s"Range scan decoding not supported for data type: ${field.dataType}") + } + } + fieldIdx += 2 + } + + rowWriter.getRow() + } + override def encodeKey(row: UnsafeRow): Array[Byte] = { // This prefix key has the columns specified by orderingOrdinals val prefixKey = extractPrefixKey(row) - val rangeScanKeyEncoded = encodeUnsafeRow(encodePrefixKeyForRangeScan(prefixKey)) + val rangeScanKeyEncoded = if (avroEnc.isDefined) { + encodePrefixKeyForRangeScan(prefixKey, rangeScanAvroType) + } else { + encodeUnsafeRow(encodePrefixKeyForRangeScan(prefixKey)) + } val result = if (orderingOrdinals.length < keySchema.length) { - val remainingEncoded = encodeUnsafeRow(remainingKeyProjection(row)) + val remainingEncoded = if (avroEnc.isDefined) { + encodeUnsafeRowToAvro( + remainingKeyProjection(row), + avroEnc.get.keySerializer, + remainingKeyAvroType, + out + ) + } else { + encodeUnsafeRow(remainingKeyProjection(row)) + } val (encodedBytes, startingOffset) = encodeColumnFamilyPrefix( rangeScanKeyEncoded.length + remainingEncoded.length + 4 ) @@ -606,9 +947,12 @@ class RangeKeyScanStateEncoder( Platform.copyMemory(keyBytes, decodeKeyStartOffset + 4, prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET, prefixKeyEncodedLen) - val prefixKeyDecodedForRangeScan = decodeToUnsafeRow(prefixKeyEncoded, - numFields = orderingOrdinals.length) - val prefixKeyDecoded = decodePrefixKeyForRangeScan(prefixKeyDecodedForRangeScan) + val prefixKeyDecoded = if (avroEnc.isDefined) { + decodePrefixKeyForRangeScan(prefixKeyEncoded, rangeScanAvroType) + } else { + decodePrefixKeyForRangeScan(decodeToUnsafeRow(prefixKeyEncoded, + numFields = orderingOrdinals.length)) + } if (orderingOrdinals.length < keySchema.length) { // Here we calculate the remainingKeyEncodedLen leveraging the length of keyBytes @@ -620,8 +964,14 @@ class RangeKeyScanStateEncoder( remainingKeyEncoded, Platform.BYTE_ARRAY_OFFSET, remainingKeyEncodedLen) - val remainingKeyDecoded = decodeToUnsafeRow(remainingKeyEncoded, - numFields = keySchema.length - orderingOrdinals.length) + val remainingKeyDecoded = if (avroEnc.isDefined) { + decodeFromAvroToUnsafeRow(remainingKeyEncoded, + avroEnc.get.keyDeserializer, + remainingKeyAvroType, remainingKeyAvroProjection) + } else { + decodeToUnsafeRow(remainingKeyEncoded, + numFields = keySchema.length - orderingOrdinals.length) + } val joined = joinedRowOnKey.withLeft(prefixKeyDecoded).withRight(remainingKeyDecoded) val restored = restoreKeyProjection(joined) @@ -634,7 +984,11 @@ class RangeKeyScanStateEncoder( } override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = { - val rangeScanKeyEncoded = encodeUnsafeRow(encodePrefixKeyForRangeScan(prefixKey)) + val rangeScanKeyEncoded = if (avroEnc.isDefined) { + encodePrefixKeyForRangeScan(prefixKey, rangeScanAvroType) + } else { + encodeUnsafeRow(encodePrefixKeyForRangeScan(prefixKey)) + } val (prefix, startingOffset) = encodeColumnFamilyPrefix(rangeScanKeyEncoded.length + 4) Platform.putInt(prefix, startingOffset, rangeScanKeyEncoded.length) @@ -653,6 +1007,7 @@ class RangeKeyScanStateEncoder( * It uses the first byte of the generated byte array to store the version the describes how the * row is encoded in the rest of the byte 
array. Currently, the default version is 0, * + * If the avroEnc is specified, we are using Avro encoding for this column family's keys * VERSION 0: [ VERSION (1 byte) | ROW (N bytes) ] * The bytes of a UnsafeRow is written unmodified to starting from offset 1 * (offset 0 is the version byte of value 0). That is, if the unsafe row has N bytes, @@ -661,19 +1016,27 @@ class RangeKeyScanStateEncoder( class NoPrefixKeyStateEncoder( keySchema: StructType, useColumnFamilies: Boolean = false, - virtualColFamilyId: Option[Short] = None) - extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) { + virtualColFamilyId: Option[Short] = None, + avroEnc: Option[AvroEncoder] = None) + extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) with Logging { import RocksDBStateEncoder._ // Reusable objects + private val usingAvroEncoding = avroEnc.isDefined private val keyRow = new UnsafeRow(keySchema.size) + private lazy val keyAvroType = SchemaConverters.toAvroType(keySchema) + private val keyProj = UnsafeProjection.create(keySchema) override def encodeKey(row: UnsafeRow): Array[Byte] = { if (!useColumnFamilies) { encodeUnsafeRow(row) } else { - val bytesToEncode = row.getBytes + // If avroEnc is defined, we know that we need to use Avro to + // encode this UnsafeRow to Avro bytes + val bytesToEncode = if (usingAvroEncoding) { + encodeUnsafeRowToAvro(row, avroEnc.get.keySerializer, keyAvroType, out) + } else row.getBytes val (encodedBytes, startingOffset) = encodeColumnFamilyPrefix( bytesToEncode.length + STATE_ENCODING_NUM_VERSION_BYTES @@ -697,11 +1060,21 @@ class NoPrefixKeyStateEncoder( if (useColumnFamilies) { if (keyBytes != null) { // Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st offset. See Platform. - keyRow.pointTo( - keyBytes, - decodeKeyStartOffset + STATE_ENCODING_NUM_VERSION_BYTES, - keyBytes.length - STATE_ENCODING_NUM_VERSION_BYTES - VIRTUAL_COL_FAMILY_PREFIX_BYTES) - keyRow + if (usingAvroEncoding) { + val dataLength = keyBytes.length - STATE_ENCODING_NUM_VERSION_BYTES - + VIRTUAL_COL_FAMILY_PREFIX_BYTES + val avroBytes = new Array[Byte](dataLength) + Platform.copyMemory( + keyBytes, decodeKeyStartOffset + STATE_ENCODING_NUM_VERSION_BYTES, + avroBytes, 0, dataLength) + decodeFromAvroToUnsafeRow(avroBytes, avroEnc.get.keyDeserializer, keyAvroType, keyProj) + } else { + keyRow.pointTo( + keyBytes, + decodeKeyStartOffset + STATE_ENCODING_NUM_VERSION_BYTES, + keyBytes.length - STATE_ENCODING_NUM_VERSION_BYTES - VIRTUAL_COL_FAMILY_PREFIX_BYTES) + keyRow + } } else { null } @@ -727,17 +1100,28 @@ class NoPrefixKeyStateEncoder( * This encoder supports RocksDB StringAppendOperator merge operator. Values encoded can be * merged in RocksDB using merge operation, and all merged values can be read using decodeValues * operation. 
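+ * When Avro encoding is used, values keep the same [length (4 bytes) | value bytes] layout and
+ * merge-delimiter handling; only the serialized value payload changes from UnsafeRow bytes to
+ * Avro bytes.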
+ * If the avroEnc is specified, we are using Avro encoding for this column family's values */ -class MultiValuedStateEncoder(valueSchema: StructType) +class MultiValuedStateEncoder( + valueSchema: StructType, + avroEnc: Option[AvroEncoder] = None) extends RocksDBValueStateEncoder with Logging { import RocksDBStateEncoder._ + private val usingAvroEncoding = avroEnc.isDefined // Reusable objects + private val out = new ByteArrayOutputStream private val valueRow = new UnsafeRow(valueSchema.size) + private lazy val valueAvroType = SchemaConverters.toAvroType(valueSchema) + private val valueProj = UnsafeProjection.create(valueSchema) override def encodeValue(row: UnsafeRow): Array[Byte] = { - val bytes = encodeUnsafeRow(row) + val bytes = if (usingAvroEncoding) { + encodeUnsafeRowToAvro(row, avroEnc.get.valueSerializer, valueAvroType, out) + } else { + encodeUnsafeRow(row) + } val numBytes = bytes.length val encodedBytes = new Array[Byte](java.lang.Integer.BYTES + bytes.length) @@ -756,7 +1140,12 @@ class MultiValuedStateEncoder(valueSchema: StructType) val encodedValue = new Array[Byte](numBytes) Platform.copyMemory(valueBytes, java.lang.Integer.BYTES + Platform.BYTE_ARRAY_OFFSET, encodedValue, Platform.BYTE_ARRAY_OFFSET, numBytes) - decodeToUnsafeRow(encodedValue, valueRow) + if (usingAvroEncoding) { + decodeFromAvroToUnsafeRow( + encodedValue, avroEnc.get.valueDeserializer, valueAvroType, valueProj) + } else { + decodeToUnsafeRow(encodedValue, valueRow) + } } } @@ -782,7 +1171,12 @@ class MultiValuedStateEncoder(valueSchema: StructType) pos += numBytes pos += 1 // eat the delimiter character - decodeToUnsafeRow(encodedValue, valueRow) + if (usingAvroEncoding) { + decodeFromAvroToUnsafeRow( + encodedValue, avroEnc.get.valueDeserializer, valueAvroType, valueProj) + } else { + decodeToUnsafeRow(encodedValue, valueRow) + } } } } @@ -802,16 +1196,29 @@ class MultiValuedStateEncoder(valueSchema: StructType) * The bytes of a UnsafeRow is written unmodified to starting from offset 1 * (offset 0 is the version byte of value 0). That is, if the unsafe row has N bytes, * then the generated array byte will be N+1 bytes. + * If the avroEnc is specified, we are using Avro encoding for this column family's values */ -class SingleValueStateEncoder(valueSchema: StructType) - extends RocksDBValueStateEncoder { +class SingleValueStateEncoder( + valueSchema: StructType, + avroEnc: Option[AvroEncoder] = None) + extends RocksDBValueStateEncoder with Logging { import RocksDBStateEncoder._ + private val usingAvroEncoding = avroEnc.isDefined // Reusable objects + private val out = new ByteArrayOutputStream private val valueRow = new UnsafeRow(valueSchema.size) + private lazy val valueAvroType = SchemaConverters.toAvroType(valueSchema) + private val valueProj = UnsafeProjection.create(valueSchema) - override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row) + override def encodeValue(row: UnsafeRow): Array[Byte] = { + if (usingAvroEncoding) { + encodeUnsafeRowToAvro(row, avroEnc.get.valueSerializer, valueAvroType, out) + } else { + encodeUnsafeRow(row) + } + } /** * Decode byte array for a value to a UnsafeRow. @@ -820,7 +1227,15 @@ class SingleValueStateEncoder(valueSchema: StructType) * the given byte array. 
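+   * Returns null if the given `valueBytes` is null.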
*/ override def decodeValue(valueBytes: Array[Byte]): UnsafeRow = { - decodeToUnsafeRow(valueBytes, valueRow) + if (valueBytes == null) { + return null + } + if (usingAvroEncoding) { + decodeFromAvroToUnsafeRow( + valueBytes, avroEnc.get.valueDeserializer, valueAvroType, valueProj) + } else { + decodeToUnsafeRow(valueBytes, valueRow) + } } override def supportsMultipleValuesPerKey: Boolean = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 1fc6ab5910c6c..a0711c2921b51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -29,6 +29,7 @@ import org.apache.spark.{SparkConf, SparkEnv, SparkException} import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys._ import org.apache.spark.io.CompressionCodec +import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions, AvroSerializer, SchemaConverters} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.streaming.CheckpointFileManager @@ -74,10 +75,68 @@ private[sql] class RocksDBStateStoreProvider isInternal: Boolean = false): Unit = { verifyColFamilyCreationOrDeletion("create_col_family", colFamilyName, isInternal) val newColFamilyId = rocksDB.createColFamilyIfAbsent(colFamilyName) + // Create cache key using store ID to avoid collisions + val avroEncCacheKey = s"${stateStoreId.operatorId}_" + + s"${stateStoreId.partitionId}_$colFamilyName" + + lazy val avroEnc = stateStoreEncoding match { + case "avro" => Some( + RocksDBStateStoreProvider.avroEncoderMap.computeIfAbsent(avroEncCacheKey, + _ => getAvroEnc(keyStateEncoderSpec, valueSchema)) + ) + case "unsaferow" => None + } + keyValueEncoderMap.putIfAbsent(colFamilyName, (RocksDBStateEncoder.getKeyEncoder(keyStateEncoderSpec, useColumnFamilies, - Some(newColFamilyId)), RocksDBStateEncoder.getValueEncoder(valueSchema, - useMultipleValuesPerKey))) + Some(newColFamilyId), avroEnc), RocksDBStateEncoder.getValueEncoder(valueSchema, + useMultipleValuesPerKey, avroEnc))) + } + private def getAvroSerializer(schema: StructType): AvroSerializer = { + val avroType = SchemaConverters.toAvroType(schema) + new AvroSerializer(schema, avroType, nullable = false) + } + + private def getAvroDeserializer(schema: StructType): AvroDeserializer = { + val avroType = SchemaConverters.toAvroType(schema) + val avroOptions = AvroOptions(Map.empty) + new AvroDeserializer(avroType, schema, + avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType, + avroOptions.stableIdPrefixForUnionType, avroOptions.recursiveFieldMaxDepth) + } + + private def getAvroEnc( + keyStateEncoderSpec: KeyStateEncoderSpec, + valueSchema: StructType + ): AvroEncoder = { + val valueSerializer = getAvroSerializer(valueSchema) + val valueDeserializer = getAvroDeserializer(valueSchema) + val keySchema = keyStateEncoderSpec match { + case NoPrefixKeyStateEncoderSpec(schema) => + schema + case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) => + StructType(schema.take(numColsPrefixKey)) + case RangeKeyScanStateEncoderSpec(schema, orderingOrdinals) => + val remainingSchema = { + 0.until(schema.length).diff(orderingOrdinals).map { ordinal => + 
schema(ordinal) + } + } + StructType(remainingSchema) + } + val suffixKeySchema = keyStateEncoderSpec match { + case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) => + Some(StructType(schema.drop(numColsPrefixKey))) + case _ => None + } + AvroEncoder( + getAvroSerializer(keySchema), + getAvroDeserializer(keySchema), + valueSerializer, + valueDeserializer, + suffixKeySchema.map(getAvroSerializer), + suffixKeySchema.map(getAvroDeserializer) + ) } override def get(key: UnsafeRow, colFamilyName: String): UnsafeRow = { @@ -364,6 +423,7 @@ private[sql] class RocksDBStateStoreProvider this.storeConf = storeConf this.hadoopConf = hadoopConf this.useColumnFamilies = useColumnFamilies + this.stateStoreEncoding = storeConf.stateStoreEncodingFormat if (useMultipleValuesPerKey) { require(useColumnFamilies, "Multiple values per key support requires column families to be" + @@ -458,6 +518,7 @@ private[sql] class RocksDBStateStoreProvider @volatile private var storeConf: StateStoreConf = _ @volatile private var hadoopConf: Configuration = _ @volatile private var useColumnFamilies: Boolean = _ + @volatile private var stateStoreEncoding: String = _ private[sql] lazy val rocksDB = { val dfsRootDir = stateStoreId.storeCheckpointLocation().toString @@ -593,6 +654,9 @@ object RocksDBStateStoreProvider { val STATE_ENCODING_VERSION: Byte = 0 val VIRTUAL_COL_FAMILY_PREFIX_BYTES = 2 + // Add the cache at companion object level so it persists across provider instances + private val avroEncoderMap = new java.util.concurrent.ConcurrentHashMap[String, AvroEncoder]() + // Native operation latencies report as latency in microseconds // as SQLMetrics support millis. Convert the value to millis val CUSTOM_METRIC_GET_TIME = StateStoreCustomTimingMetric( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala index 721d72b6a0991..b5f2f318de418 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala @@ -24,6 +24,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.internal.{Logging, LogKeys, MDC} +import org.apache.spark.sql.avro.{AvroDeserializer, AvroSerializer} import org.apache.spark.sql.catalyst.util.UnsafeRowUtils import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, StatefulOperatorStateInfo} import org.apache.spark.sql.execution.streaming.state.SchemaHelper.{SchemaReader, SchemaWriter} @@ -37,14 +38,26 @@ case class StateSchemaValidationResult( schemaPath: String ) +// Avro encoder that is used by the RocksDBStateStoreProvider and RocksDBStateEncoder +// in order to serialize from UnsafeRow to a byte array of Avro encoding. 
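+// The suffix key serializer/deserializer are only defined for prefix-scan key encoders, where
+// the remaining (non-prefix) key columns are serialized separately from the prefix key.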
+case class AvroEncoder(
+    keySerializer: AvroSerializer,
+    keyDeserializer: AvroDeserializer,
+    valueSerializer: AvroSerializer,
+    valueDeserializer: AvroDeserializer,
+    suffixKeySerializer: Option[AvroSerializer] = None,
+    suffixKeyDeserializer: Option[AvroDeserializer] = None
+) extends Serializable
+
 // Used to represent the schema of a column family in the state store
 case class StateStoreColFamilySchema(
     colFamilyName: String,
     keySchema: StructType,
     valueSchema: StructType,
     keyStateEncoderSpec: Option[KeyStateEncoderSpec] = None,
-    userKeyEncoderSchema: Option[StructType] = None
-)
+    userKeyEncoderSchema: Option[StructType] = None,
+    avroEnc: Option[AvroEncoder] = None
+) extends Serializable

 class StateSchemaCompatibilityChecker(
     providerId: StateStoreProviderId,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 72bc3ca33054d..c3449b6a12626 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -41,6 +41,18 @@ import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.{NextIterator, ThreadUtils, Utils}

+sealed trait StateStoreEncoding {
+  override def toString: String = this match {
+    case StateStoreEncoding.UnsafeRow => "unsaferow"
+    case StateStoreEncoding.Avro => "avro"
+  }
+}
+
+object StateStoreEncoding {
+  case object UnsafeRow extends StateStoreEncoding
+  case object Avro extends StateStoreEncoding
+}
+
 /**
  * Base trait for a versioned key-value store which provides read operations. Each instance of a
  * `ReadStateStore` represents a specific version of state data, and such instances are created
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
index c8af395e996d8..9d26bf8fdf2e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -83,6 +83,9 @@ class StateStoreConf(
   /** The interval of maintenance tasks. */
   val maintenanceInterval = sqlConf.streamingMaintenanceInterval

+  /** The encoding format used by stateful operators to store rows in the state store. */
+  val stateStoreEncodingFormat = sqlConf.stateStoreEncodingFormat
+
   /**
    * When creating new state store checkpoint, which format version to use.
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala index 22876831c00d1..b20e6d7d49a9a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala @@ -60,7 +60,7 @@ class ListStateSuite extends StateVariableSuiteBase { } Seq("appendList", "put").foreach { listImplFunc => - test(s"Test list operation($listImplFunc) with null") { + testWithEncodingTypes(s"Test list operation($listImplFunc) with null") { testMapStateWithNullUserKey() { listState => listImplFunc match { case "appendList" => listState.appendList(null) @@ -70,7 +70,7 @@ class ListStateSuite extends StateVariableSuiteBase { } } - test("List state operations for single instance") { + testWithEncodingTypes("List state operations for single instance") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), @@ -99,7 +99,7 @@ class ListStateSuite extends StateVariableSuiteBase { } } - test("List state operations for multiple instance") { + testWithEncodingTypes("List state operations for multiple instance") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), @@ -139,7 +139,7 @@ class ListStateSuite extends StateVariableSuiteBase { } } - test("List state operations with list, value, another list instances") { + testWithEncodingTypes("List state operations with list, value, another list instances") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), @@ -170,7 +170,7 @@ class ListStateSuite extends StateVariableSuiteBase { } } - test(s"test List state TTL") { + testWithEncodingTypes(s"test List state TTL") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val timestampMs = 10 @@ -226,7 +226,7 @@ class ListStateSuite extends StateVariableSuiteBase { } } - test("test null or negative TTL duration throws error") { + testWithEncodingTypes("test null or negative TTL duration throws error") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val batchTimestampMs = 10 @@ -253,7 +253,7 @@ class ListStateSuite extends StateVariableSuiteBase { } } - test("ListState TTL with non-primitive types") { + testWithEncodingTypes("ListState TTL with non-primitive types") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val timestampMs = 10 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala index 9a0a891d538ec..b913133deddeb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala @@ -38,7 +38,7 @@ class MapStateSuite extends StateVariableSuiteBase { import testImplicits._ - test("Map state operations for 
single instance") { + testWithEncodingTypes("Map state operations for single instance") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), @@ -73,7 +73,7 @@ class MapStateSuite extends StateVariableSuiteBase { } } - test("Map state operations for multiple map instances") { + testWithEncodingTypes("Map state operations for multiple map instances") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), @@ -113,7 +113,7 @@ class MapStateSuite extends StateVariableSuiteBase { } } - test("Map state operations with list, value, another map instances") { + testWithEncodingTypes("Map state operations with list, value, another map instances") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), @@ -173,7 +173,7 @@ class MapStateSuite extends StateVariableSuiteBase { } } - test("test Map state TTL") { + testWithEncodingTypes("test Map state TTL") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val timestampMs = 10 @@ -231,7 +231,7 @@ class MapStateSuite extends StateVariableSuiteBase { } } - test("test null or negative TTL duration throws error") { + testWithEncodingTypes("test null or negative TTL duration throws error") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val batchTimestampMs = 10 @@ -259,7 +259,7 @@ class MapStateSuite extends StateVariableSuiteBase { } } - test("Map state with TTL with non-primitive types") { + testWithEncodingTypes("Map state with TTL with non-primitive types") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val timestampMs = 10 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index e1bd9dd38066b..c895aced455e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -58,7 +58,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid import StateStoreTestsHelper._ - testWithColumnFamilies(s"version encoding", + testWithColumnFamiliesAndEncodingTypes(s"version encoding", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => import RocksDBStateStoreProvider._ @@ -127,7 +127,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamilies("rocksdb file manager metrics exposed", + testWithColumnFamiliesAndEncodingTypes("rocksdb file manager metrics exposed", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => import RocksDBStateStoreProvider._ def getCustomMetric(metrics: StateStoreMetrics, @@ -162,7 +162,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamilies("rocksdb range scan validation - invalid num columns", + testWithColumnFamiliesAndEncodingTypes("rocksdb range scan 
validation - invalid num columns", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => // zero ordering cols val ex1 = intercept[SparkUnsupportedOperationException] { @@ -201,7 +201,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid ) } - testWithColumnFamilies("rocksdb range scan validation - variable sized columns", + testWithColumnFamiliesAndEncodingTypes("rocksdb range scan validation - variable sized columns", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => val keySchemaWithVariableSizeCols: StructType = StructType( Seq(StructField("key1", StringType, false), StructField("key2", StringType, false))) @@ -224,7 +224,8 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid ) } - testWithColumnFamilies("rocksdb range scan validation - variable size data types unsupported", + testWithColumnFamiliesAndEncodingTypes("rocksdb range scan validation - " + + "variable size data types unsupported", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => val keySchemaWithSomeUnsupportedTypeCols: StructType = StructType(Seq( StructField("key1", StringType, false), @@ -264,7 +265,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamilies("rocksdb range scan validation - null type columns", + testWithColumnFamiliesAndEncodingTypes("rocksdb range scan validation - null type columns", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => val keySchemaWithNullTypeCols: StructType = StructType( Seq(StructField("key1", NullType, false), StructField("key2", StringType, false))) @@ -287,7 +288,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid ) } - testWithColumnFamilies("rocksdb range scan - fixed size non-ordering columns", + testWithColumnFamiliesAndEncodingTypes("rocksdb range scan - fixed size non-ordering columns", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, @@ -339,7 +340,8 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamilies("rocksdb range scan - variable size non-ordering columns with " + + testWithColumnFamiliesAndEncodingTypes("rocksdb range scan - variable " + + "size non-ordering columns with " + "double type values are supported", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => @@ -395,7 +397,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamilies("rocksdb range scan - variable size non-ordering columns", + testWithColumnFamiliesAndEncodingTypes("rocksdb range scan - variable size non-ordering columns", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, @@ -448,7 +450,8 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamilies("rocksdb range scan multiple ordering columns - variable size " + + testWithColumnFamiliesAndEncodingTypes("rocksdb range scan multiple " + + "ordering columns - variable size " + s"non-ordering columns", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => @@ -492,15 +495,16 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamilies("rocksdb range scan 
multiple non-contiguous ordering columns", + testWithColumnFamiliesAndEncodingTypes("rocksdb range scan multiple " + + "non-contiguous ordering columns", TestWithBothChangelogCheckpointingEnabledAndDisabled ) { colFamiliesEnabled => val testSchema: StructType = StructType( Seq( - StructField("ordering-1", LongType, false), + StructField("ordering1", LongType, false), StructField("key2", StringType, false), - StructField("ordering-2", IntegerType, false), - StructField("string-2", StringType, false), - StructField("ordering-3", DoubleType, false) + StructField("ordering2", IntegerType, false), + StructField("string2", StringType, false), + StructField("ordering3", DoubleType, false) ) ) @@ -582,7 +586,8 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } - testWithColumnFamilies("rocksdb range scan multiple ordering columns - variable size " + + testWithColumnFamiliesAndEncodingTypes("rocksdb range scan multiple " + + "ordering columns - variable size " + s"non-ordering columns with null values in first ordering column", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => @@ -682,7 +687,8 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamilies("rocksdb range scan multiple ordering columns - variable size " + + testWithColumnFamiliesAndEncodingTypes("rocksdb range scan multiple " + + "ordering columns - variable size " + s"non-ordering columns with null values in second ordering column", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => @@ -735,7 +741,8 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamilies("rocksdb range scan byte ordering column - variable size " + + testWithColumnFamiliesAndEncodingTypes("rocksdb range scan byte " + + "ordering column - variable size " + s"non-ordering columns", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => @@ -779,7 +786,8 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamilies("rocksdb range scan - ordering cols and key schema cols are same", + testWithColumnFamiliesAndEncodingTypes("rocksdb range scan - ordering cols " + + "and key schema cols are same", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => // use the same schema as value schema for single col key schema @@ -821,7 +829,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamilies("rocksdb range scan - with prefix scan", + testWithColumnFamiliesAndEncodingTypes("rocksdb range scan - with prefix scan", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, @@ -858,7 +866,8 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamilies("rocksdb key and value schema encoders for column families", + testWithColumnFamiliesAndEncodingTypes("rocksdb key and value schema encoders " + + "for column families", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => val testColFamily = "testState" @@ -919,7 +928,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } /* Column family related tests */ - testWithColumnFamilies("column family creation with invalid names", + testWithColumnFamiliesAndEncodingTypes("column family creation with invalid 
names", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => tryWithProviderResource( newStoreProvider(useColumnFamilies = colFamiliesEnabled)) { provider => @@ -956,7 +965,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamilies(s"column family creation with reserved chars", + testWithColumnFamiliesAndEncodingTypes(s"column family creation with reserved chars", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => tryWithProviderResource( newStoreProvider(useColumnFamilies = colFamiliesEnabled)) { provider => @@ -992,7 +1001,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamilies(s"operations on absent column family", + testWithColumnFamiliesAndEncodingTypes(s"operations on absent column family", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => tryWithProviderResource( newStoreProvider(useColumnFamilies = colFamiliesEnabled)) { provider => @@ -1145,7 +1154,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid Seq( NoPrefixKeyStateEncoderSpec(keySchema), PrefixKeyScanStateEncoderSpec(keySchema, 1) ).foreach { keyEncoder => - testWithColumnFamilies(s"validate rocksdb " + + testWithColumnFamiliesAndEncodingTypes(s"validate rocksdb " + s"${keyEncoder.getClass.toString.split('.').last} correctness", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => tryWithProviderResource(newStoreProvider(keySchema, keyEncoder, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala index 637eb49130305..5f3b62320a35c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala @@ -128,6 +128,73 @@ trait AlsoTestWithChangelogCheckpointingEnabled } } + def testWithEncodingTypes(testName: String, testTags: Tag*) + (testBody: => Any): Unit = { + Seq("unsaferow", "avro").foreach { encoding => + super.test(testName + s" (encoding = $encoding)", testTags: _*) { + // in case tests have any code that needs to execute before every test + // super.beforeEach() + withSQLConf( + SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key -> + encoding) { + testBody + } + // in case tests have any code that needs to execute after every test + // super.afterEach() + } + } + } + + def testWithColumnFamiliesAndEncodingTypes( + testName: String, + testMode: TestMode, + testTags: Tag*) + (testBody: Boolean => Any): Unit = { + + Seq(true, false).foreach { colFamiliesEnabled => + Seq("unsaferow", "avro").foreach { encoding => + testMode match { + case TestWithChangelogCheckpointingEnabled => + testWithChangelogCheckpointingEnabled( + s"$testName - with colFamiliesEnabled=$colFamiliesEnabled (encoding = $encoding)", + testTags: _*) { + withSQLConf(SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key -> encoding) { + testBody(colFamiliesEnabled) + } + } + + case TestWithChangelogCheckpointingDisabled => + testWithChangelogCheckpointingDisabled( + s"$testName - with colFamiliesEnabled=$colFamiliesEnabled (encoding = $encoding)", + testTags: _*) { + withSQLConf(SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key -> encoding) { + testBody(colFamiliesEnabled) + } + } + + case TestWithBothChangelogCheckpointingEnabledAndDisabled 
=> + testWithChangelogCheckpointingEnabled( + s"$testName - with colFamiliesEnabled=$colFamiliesEnabled (encoding = $encoding)", + testTags: _*) { + withSQLConf(SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key -> encoding) { + testBody(colFamiliesEnabled) + } + } + testWithChangelogCheckpointingDisabled( + s"$testName - with colFamiliesEnabled=$colFamiliesEnabled (encoding = $encoding)", + testTags: _*) { + withSQLConf(SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key -> encoding) { + testBody(colFamiliesEnabled) + } + } + + case _ => + throw new IllegalArgumentException(s"Unknown test mode: $testMode") + } + } + } + } + def testWithColumnFamilies( testName: String, testMode: TestMode, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala index 24a120be9d9af..be52691315259 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala @@ -30,7 +30,7 @@ class TimerSuite extends StateVariableSuiteBase { private def testWithTimeMode(testName: String) (testFunc: TimeMode => Unit): Unit = { Seq("Processing", "Event").foreach { timeoutMode => - test(s"$timeoutMode timer - " + testName) { + testWithEncodingTypes(s"$timeoutMode timer - " + testName) { timeoutMode match { case "Processing" => testFunc(TimeMode.ProcessingTime()) case "Event" => testFunc(TimeMode.EventTime()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala index 55d08cd8f12a7..3d7b56a286b95 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala @@ -46,7 +46,7 @@ class ValueStateSuite extends StateVariableSuiteBase { import StateStoreTestsHelper._ import testImplicits._ - test("Implicit key operations") { + testWithEncodingTypes("Implicit key operations") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), @@ -91,7 +91,7 @@ class ValueStateSuite extends StateVariableSuiteBase { } } - test("Value state operations for single instance") { + testWithEncodingTypes("Value state operations for single instance") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), @@ -118,7 +118,7 @@ class ValueStateSuite extends StateVariableSuiteBase { } } - test("Value state operations for multiple instances") { + testWithEncodingTypes("Value state operations for multiple instances") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), @@ -163,7 +163,7 @@ class ValueStateSuite extends StateVariableSuiteBase { } } - test("Value state operations for unsupported type name should fail") { + testWithEncodingTypes("Value state operations for unsupported type name should fail") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new 
StatefulProcessorHandleImpl(store, @@ -184,7 +184,7 @@ class ValueStateSuite extends StateVariableSuiteBase { } } - test("colFamily with HDFSBackedStateStoreProvider should fail") { + testWithEncodingTypes("colFamily with HDFSBackedStateStoreProvider should fail") { val storeId = StateStoreId(newDir(), Random.nextInt(), 0) val provider = new HDFSBackedStateStoreProvider() val storeConf = new StateStoreConf(new SQLConf()) @@ -203,7 +203,8 @@ class ValueStateSuite extends StateVariableSuiteBase { ) } - test("test SQL encoder - Value state operations for Primitive(Double) instances") { + testWithEncodingTypes("test SQL encoder - Value state operations for " + + "Primitive(Double) instances") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), @@ -229,7 +230,8 @@ class ValueStateSuite extends StateVariableSuiteBase { } } - test("test SQL encoder - Value state operations for Primitive(Long) instances") { + testWithEncodingTypes("test SQL encoder - Value state operations " + + "for Primitive(Long) instances") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), @@ -255,7 +257,7 @@ class ValueStateSuite extends StateVariableSuiteBase { } } - test("test SQL encoder - Value state operations for case class instances") { + testWithEncodingTypes("test SQL encoder - Value state operations for case class instances") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), @@ -281,7 +283,7 @@ class ValueStateSuite extends StateVariableSuiteBase { } } - test("test SQL encoder - Value state operations for POJO instances") { + testWithEncodingTypes("test SQL encoder - Value state operations for POJO instances") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), @@ -307,7 +309,7 @@ class ValueStateSuite extends StateVariableSuiteBase { } } - test(s"test Value state TTL") { + testWithEncodingTypes(s"test Value state TTL") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val timestampMs = 10 @@ -363,7 +365,7 @@ class ValueStateSuite extends StateVariableSuiteBase { } } - test("test null or zero TTL duration throws error") { + testWithEncodingTypes("test null or zero TTL duration throws error") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val batchTimestampMs = 10 @@ -390,7 +392,7 @@ class ValueStateSuite extends StateVariableSuiteBase { } } - test("Value state TTL with non-primitive type") { + testWithEncodingTypes("Value state TTL with non-primitive type") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val timestampMs = 10 @@ -423,7 +425,7 @@ class ValueStateSuite extends StateVariableSuiteBase { * types (ValueState, ListState, MapState) used in arbitrary stateful operators. 
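
At this point in the series these suites call the testWithEncodingTypes helper instead of test, so each body runs once per state store encoding format and the format name is appended to the test title. A minimal sketch of how such a test reads, assuming a suite derived from StateVariableSuiteBase with the helper in scope and org.apache.spark.sql.internal.SQLConf imported; the suite and test names here are illustrative only, not part of the patch:

    class ExampleEncodingSuite extends StateVariableSuiteBase {
      // Registers two tests, "value round trip (encoding = unsaferow)" and
      // "value round trip (encoding = avro)"; the helper wraps the body in
      // withSQLConf(SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key -> <encoding>).
      testWithEncodingTypes("value round trip") {
        val format = spark.conf.get(SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key)
        assert(Set("unsaferow", "avro").contains(format))
      }
    }
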
*/ abstract class StateVariableSuiteBase extends SharedSparkSession - with BeforeAndAfter { + with BeforeAndAfter with AlsoTestWithChangelogCheckpointingEnabled { before { StateStore.stop() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala index 88862e2ad0791..9f26de126e0dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala @@ -130,7 +130,7 @@ class TransformWithListStateSuite extends StreamTest with AlsoTestWithChangelogCheckpointingEnabled { import testImplicits._ - test("test appending null value in list state throw exception") { + testWithEncodingTypes("test appending null value in list state throw exception") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -150,7 +150,7 @@ class TransformWithListStateSuite extends StreamTest } } - test("test putting null value in list state throw exception") { + testWithEncodingTypes("test putting null value in list state throw exception") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -170,7 +170,7 @@ class TransformWithListStateSuite extends StreamTest } } - test("test putting null list in list state throw exception") { + testWithEncodingTypes("test putting null list in list state throw exception") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -190,7 +190,7 @@ class TransformWithListStateSuite extends StreamTest } } - test("test appending null list in list state throw exception") { + testWithEncodingTypes("test appending null list in list state throw exception") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -210,7 +210,7 @@ class TransformWithListStateSuite extends StreamTest } } - test("test putting empty list in list state throw exception") { + testWithEncodingTypes("test putting empty list in list state throw exception") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -230,7 +230,7 @@ class TransformWithListStateSuite extends StreamTest } } - test("test appending empty list in list state throw exception") { + testWithEncodingTypes("test appending empty list in list state throw exception") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -250,7 +250,7 @@ class TransformWithListStateSuite extends StreamTest } } - test("test list state correctness") { + testWithEncodingTypes("test list state correctness") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -307,7 +307,7 @@ class TransformWithListStateSuite extends StreamTest } } - test("test ValueState And ListState in Processor") { + testWithEncodingTypes("test ValueState And ListState in Processor") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala index 409a255ae3e64..ebd29bff5d354 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala @@ -105,7 +105,7 @@ class TransformWithListStateTTLSuite extends TransformWithStateTTLTest { override def getStateTTLMetricName: String = "numListStateWithTTLVars" - test("verify iterator works with expired values in beginning of list") { + testWithEncodingTypes("verify iterator works with expired values in beginning of list") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> "1") { @@ -195,7 +195,7 @@ class TransformWithListStateTTLSuite extends TransformWithStateTTLTest { // ascending order of TTL by stopping the query, setting the new TTL, and restarting // the query to check that the expired elements in the middle or end of the list // are not returned. - test("verify iterator works with expired values in middle of list") { + testWithEncodingTypes("verify iterator works with expired values in middle of list") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> "1") { @@ -343,7 +343,7 @@ class TransformWithListStateTTLSuite extends TransformWithStateTTLTest { } } - test("verify iterator works with expired values in end of list") { + testWithEncodingTypes("verify iterator works with expired values in end of list") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> "1") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala index 76c5cbeee424b..410ad55cad480 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala @@ -110,7 +110,7 @@ class TransformWithMapStateSuite extends StreamTest } } - test("Test retrieving value with non-existing user key") { + testWithEncodingTypes("Test retrieving value with non-existing user key") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -129,12 +129,12 @@ class TransformWithMapStateSuite extends StreamTest } Seq("getValue", "containsKey", "updateValue", "removeKey").foreach { mapImplFunc => - test(s"Test $mapImplFunc with null user key") { + testWithEncodingTypes(s"Test $mapImplFunc with null user key") { testMapStateWithNullUserKey(InputMapRow("k1", mapImplFunc, (null, ""))) } } - test("Test put value with null value") { + testWithEncodingTypes("Test put value with null value") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -158,7 +158,7 @@ class TransformWithMapStateSuite extends StreamTest } } - test("Test map state correctness") { + testWithEncodingTypes("Test map state correctness") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { val inputData = MemoryStream[InputMapRow] @@ -219,7 +219,7 @@ class TransformWithMapStateSuite extends StreamTest } } - test("transformWithMapState - batch should succeed") { + testWithEncodingTypes("transformWithMapState - batch should succeed") { val inputData = Seq( InputMapRow("k1", "updateValue", ("v1", "10")), InputMapRow("k1", "getValue", ("v1", 
""))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala index 022280eb3bcef..a68632534c001 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala @@ -182,7 +182,7 @@ class TransformWithMapStateTTLSuite extends TransformWithStateTTLTest { override def getStateTTLMetricName: String = "numMapStateWithTTLVars" - test("validate state is evicted with multiple user keys") { + testWithEncodingTypes("validate state is evicted with multiple user keys") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> "1") { @@ -224,7 +224,7 @@ class TransformWithMapStateTTLSuite extends TransformWithStateTTLTest { } } - test("verify iterator doesn't return expired keys") { + testWithEncodingTypes("verify iterator doesn't return expired keys") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> "1") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 3ef5c57ee3d07..5a3cfa715e300 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -409,7 +409,8 @@ class TransformWithStateSuite extends StateStoreMetricsTest import testImplicits._ - test("transformWithState - streaming with rocksdb and invalid processor should fail") { + testWithEncodingTypes("transformWithState - streaming with rocksdb and" + + " invalid processor should fail") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -430,7 +431,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - lazy iterators can properly get/set keyed state") { + testWithEncodingTypes("transformWithState - lazy iterators can properly get/set keyed state") { val spark = this.spark import spark.implicits._ @@ -508,7 +509,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - streaming with rocksdb should succeed") { + testWithEncodingTypes("transformWithState - streaming with rocksdb should succeed") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -546,7 +547,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - streaming with rocksdb and processing time timer " + + testWithEncodingTypes("transformWithState - streaming with rocksdb and processing time timer " + "should succeed") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -591,7 +592,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - streaming with rocksdb and processing time timer " + + testWithEncodingTypes("transformWithState - streaming with rocksdb and processing time timer " + "and updating timers should succeed") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> 
classOf[RocksDBStateStoreProvider].getName) { @@ -627,7 +628,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - streaming with rocksdb and processing time timer " + + testWithEncodingTypes("transformWithState - streaming with rocksdb and processing time timer " + "and multiple timers should succeed") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -664,7 +665,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - streaming with rocksdb and event time based timer") { + testWithEncodingTypes("transformWithState - streaming with rocksdb and event time based timer") { val inputData = MemoryStream[(String, Int)] val result = inputData.toDS() @@ -708,7 +709,8 @@ class TransformWithStateSuite extends StateStoreMetricsTest ) } - test("Use statefulProcessor without transformWithState - handle should be absent") { + testWithEncodingTypes("Use statefulProcessor without " + + "transformWithState - handle should be absent") { val processor = new RunningCountStatefulProcessor() val ex = intercept[Exception] { processor.getHandle @@ -720,7 +722,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest ) } - test("transformWithState - batch should succeed") { + testWithEncodingTypes("transformWithState - batch should succeed") { val inputData = Seq("a", "b") val result = inputData.toDS() .groupByKey(x => x) @@ -732,7 +734,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest checkAnswer(df, Seq(("a", "1"), ("b", "1")).toDF()) } - test("transformWithState - test deleteIfExists operator") { + testWithEncodingTypes("transformWithState - test deleteIfExists operator") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -773,7 +775,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - two input streams") { + testWithEncodingTypes("transformWithState - two input streams") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -803,7 +805,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - three input streams") { + testWithEncodingTypes("transformWithState - three input streams") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -838,7 +840,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - two input streams, different key type") { + testWithEncodingTypes("transformWithState - two input streams, different key type") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -885,7 +887,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest OutputMode.Update()) } - test("transformWithState - availableNow trigger mode, rate limit is respected") { + testWithEncodingTypes("transformWithState - availableNow trigger mode, rate limit is respected") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { withTempDir { srcDir => @@ -926,7 +928,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - availableNow trigger mode, multiple restarts") { + testWithEncodingTypes("transformWithState - 
availableNow trigger mode, multiple restarts") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { withTempDir { srcDir => @@ -964,7 +966,8 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - verify StateSchemaV3 writes correct SQL schema of key/value") { + testWithEncodingTypes("transformWithState - verify StateSchemaV3 writes " + + "correct SQL schema of key/value") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -1046,7 +1049,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - verify that OperatorStateMetadataV2" + + testWithEncodingTypes("transformWithState - verify that OperatorStateMetadataV2" + " file is being written correctly") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, @@ -1090,7 +1093,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("test that invalid schema evolution fails query for column family") { + testWithEncodingTypes("test that invalid schema evolution fails query for column family") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -1127,7 +1130,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("test that different outputMode after query restart fails") { + testWithEncodingTypes("test that different outputMode after query restart fails") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -1170,7 +1173,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("test that changing between different state variable types fails") { + testWithEncodingTypes("test that changing between different state variable types fails") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -1212,7 +1215,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("test that different timeMode after query restart fails") { + testWithEncodingTypes("test that different timeMode after query restart fails") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -1259,7 +1262,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("test that introducing TTL after restart fails query") { + testWithEncodingTypes("test that introducing TTL after restart fails query") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -1313,7 +1316,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("test query restart with new state variable succeeds") { + testWithEncodingTypes("test query restart with new state variable succeeds") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -1359,7 +1362,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("test query restart succeeds") { + testWithEncodingTypes("test query restart succeeds") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -1394,7 +1397,7 @@ 
class TransformWithStateSuite extends StateStoreMetricsTest } } - test("SPARK-49070: transformWithState - valid initial state plan") { + testWithEncodingTypes("SPARK-49070: transformWithState - valid initial state plan") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { withTempDir { srcDir => @@ -1444,7 +1447,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest new Path(stateCheckpointPath, "_stateSchema/default/") } - test("transformWithState - verify that metadata and schema logs are purged") { + testWithEncodingTypes("transformWithState - verify that metadata and schema logs are purged") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -1535,7 +1538,8 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("transformWithState - verify that schema file is kept after metadata is purged") { + testWithEncodingTypes("transformWithState - verify that schema file " + + "is kept after metadata is purged") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala index 2ddf69aa49e04..92a06f1183935 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala @@ -21,7 +21,7 @@ import java.sql.Timestamp import java.time.Duration import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider +import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock @@ -41,7 +41,7 @@ case class OutputEvent( * Test suite base for TransformWithState with TTL support. 
*/ abstract class TransformWithStateTTLTest - extends StreamTest { + extends StreamTest with AlsoTestWithChangelogCheckpointingEnabled { import testImplicits._ def getProcessor(ttlConfig: TTLConfig): StatefulProcessor[String, InputEvent, OutputEvent] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala index 21c3beb79314c..9727c2dc8c113 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala @@ -195,7 +195,7 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest { override def getStateTTLMetricName: String = "numValueStateWithTTLVars" - test("validate multiple value states") { + testWithEncodingTypes("validate multiple value states") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { val ttlKey = "k1" @@ -262,7 +262,8 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest { } } - test("verify StateSchemaV3 writes correct SQL schema of key/value and with TTL") { + testWithEncodingTypes("verify StateSchemaV3 writes correct SQL " + + "schema of key/value and with TTL") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> From 6df3947d52e836943ccc90583acc5263ae6596bb Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Sat, 16 Nov 2024 12:54:37 -0800 Subject: [PATCH 2/9] removing unnecessary change --- .../streaming/state/StateSchemaCompatibilityChecker.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala index b5f2f318de418..69a29cdbe7a17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala @@ -55,9 +55,8 @@ case class StateStoreColFamilySchema( keySchema: StructType, valueSchema: StructType, keyStateEncoderSpec: Option[KeyStateEncoderSpec] = None, - userKeyEncoderSchema: Option[StructType] = None, - avroEnc: Option[AvroEncoder] = None -) extends Serializable + userKeyEncoderSchema: Option[StructType] = None +) class StateSchemaCompatibilityChecker( providerId: StateStoreProviderId, From c6c7cc16353f2e06a89d43d87203f45fbd5698f3 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Sat, 16 Nov 2024 13:03:34 -0800 Subject: [PATCH 3/9] avro-temp --- .../execution/streaming/state/RocksDBStateStoreProvider.scala | 3 +++ .../apache/spark/sql/streaming/TransformWithStateSuite.scala | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index a0711c2921b51..bac7ba85ca4ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -79,6 +79,9 @@ 
private[sql] class RocksDBStateStoreProvider val avroEncCacheKey = s"${stateStoreId.operatorId}_" + s"${stateStoreId.partitionId}_$colFamilyName" + // If we have not created the avroEncoder for this column family, create + // it, or look in the cache maintained in the RocksDBStateStoreProvider + // companion object lazy val avroEnc = stateStoreEncoding match { case "avro" => Some( RocksDBStateStoreProvider.avroEncoderMap.computeIfAbsent(avroEncCacheKey, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 5a3cfa715e300..be4c7c7560c6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -1538,7 +1538,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - testWithEncodingTypes("transformWithState - verify that schema file " + + test("transformWithState - verify that schema file " + "is kept after metadata is purged") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, From 3cc22bea648fd031ed60c50e8243fa2b8af2318d Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Thu, 21 Nov 2024 11:58:59 -0800 Subject: [PATCH 4/9] moving to new line --- .../spark/sql/execution/streaming/state/RocksDBSuite.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala index 5f3b62320a35c..30e96e15684b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala @@ -128,8 +128,10 @@ trait AlsoTestWithChangelogCheckpointingEnabled } } - def testWithEncodingTypes(testName: String, testTags: Tag*) - (testBody: => Any): Unit = { + def testWithEncodingTypes( + testName: String, + testTags: Tag*) + (testBody: => Any): Unit = { Seq("unsaferow", "avro").foreach { encoding => super.test(testName + s" (encoding = $encoding)", testTags: _*) { // in case tests have any code that needs to execute before every test From d45ea2a3594967191ecee504e0c5d02cf2dbf0a4 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Fri, 22 Nov 2024 01:24:06 -0800 Subject: [PATCH 5/9] burak feedback --- .../execution/streaming/StreamExecution.scala | 1 + .../streaming/state/RocksDBStateEncoder.scala | 39 ++++++++++++ .../state/RocksDBStateStoreProvider.scala | 44 +++++++++---- .../StateSchemaCompatibilityChecker.scala | 16 ++++- .../streaming/state/StateStore.scala | 3 +- .../streaming/state/ListStateSuite.scala | 14 ++--- .../streaming/state/MapStateSuite.scala | 12 ++-- .../state/RocksDBStateStoreSuite.scala | 60 ++++++++---------- .../streaming/state/RocksDBSuite.scala | 63 ++++--------------- .../streaming/state/TimerSuite.scala | 2 +- .../streaming/state/ValueStateSuite.scala | 28 ++++----- .../streaming/TransformWithStateSuite.scala | 3 +- 12 files changed, 157 insertions(+), 128 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index bd501c9357234..44202bb0d2944 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -715,6 +715,7 @@ abstract class StreamExecution( object StreamExecution { val QUERY_ID_KEY = "sql.streaming.queryId" + val RUN_ID_KEY = "sql.streaming.runId" val IS_CONTINUOUS_PROCESSING = "__is_continuous_processing" val IO_EXCEPTION_NAMES = Seq( classOf[InterruptedException].getName, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index f097a4bb265e3..d33083210c0e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -684,6 +684,26 @@ class RangeKeyScanStateEncoder( writer.getRow() } + /** + * Encodes an UnsafeRow into an Avro-compatible byte array format for range scan operations. + * + * This method transforms row data into a binary format that preserves ordering when + * used in range scans. + * For each field in the row: + * - A marker byte is written to indicate null status or sign (for numeric types) + * - The value is written in big-endian format + * + * Special handling is implemented for: + * - Null values: marked with nullValMarker followed by zero bytes + * - Negative numbers: marked with negativeValMarker + * - Floating point numbers: bit manipulation to handle sign and NaN values correctly + * + * @param row The UnsafeRow to encode + * @param avroType The Avro schema defining the structure for encoding + * @return Array[Byte] containing the Avro-encoded data that preserves ordering for range scans + * @throws UnsupportedOperationException if a field's data type is not supported for range + * scan encoding + */ def encodePrefixKeyForRangeScan( row: UnsafeRow, avroType: Schema): Array[Byte] = { @@ -805,6 +825,25 @@ class RangeKeyScanStateEncoder( out.toByteArray } + /** + * Decodes an Avro-encoded byte array back into an UnsafeRow for range scan operations. 
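
Both methods rely on the same property: a marker byte followed by big-endian value bytes compares byte-wise in the same order as the original values, which is what makes the encoded prefix usable for RocksDB range scans. A self-contained sketch of that property for a single nullable Long ordering column, independent of the Avro plumbing; the marker constants are illustrative, and the real encoder additionally adjusts the sign bit of floating point values as described above:

    import java.nio.ByteBuffer
    import java.util.Arrays

    object RangeKeyOrderSketch {
      // Illustrative markers: nulls sort first, then negatives, then non-negatives.
      private val NullMarker: Byte = 0x00
      private val NegativeMarker: Byte = 0x01
      private val PositiveMarker: Byte = 0x02

      def encode(value: Option[Long]): Array[Byte] = {
        // ByteBuffer is big-endian by default, which is what makes byte-wise
        // comparison agree with numeric comparison within each marker group.
        val buf = ByteBuffer.allocate(1 + java.lang.Long.BYTES)
        value match {
          case None    => buf.put(NullMarker).putLong(0L)
          case Some(v) => buf.put(if (v < 0) NegativeMarker else PositiveMarker).putLong(v)
        }
        buf.array()
      }

      def main(args: Array[String]): Unit = {
        val encoded = Seq(None, Some(-42L), Some(-1L), Some(0L), Some(7L)).map(encode)
        // Each adjacent pair of encodings sorts in the same order as the values.
        encoded.zip(encoded.tail).foreach { case (a, b) =>
          assert(Arrays.compareUnsigned(a, b) < 0)
        }
      }
    }
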
+ * + * This method reverses the encoding process performed by encodePrefixKeyForRangeScan: + * - Reads the marker byte to determine null status or sign + * - Reconstructs the original values from big-endian format + * - Handles special cases for floating point numbers by reversing bit manipulations + * + * The decoding process preserves the original data types and values, including: + * - Null values marked by nullValMarker + * - Sign information for numeric types + * - Proper restoration of negative floating point values + * + * @param bytes The Avro-encoded byte array to decode + * @param avroType The Avro schema defining the structure for decoding + * @return UnsafeRow containing the decoded data + * @throws UnsupportedOperationException if a field's data type is not supported for range + * scan decoding + */ def decodePrefixKeyForRangeScan( bytes: Array[Byte], avroType: Schema): UnsafeRow = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index bac7ba85ca4ec..479217f749663 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -18,10 +18,12 @@ package org.apache.spark.sql.execution.streaming.state import java.io._ -import java.util.concurrent.ConcurrentHashMap +import java.util.UUID +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import scala.util.control.NonFatal +import com.google.common.cache.CacheBuilder import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -32,9 +34,9 @@ import org.apache.spark.io.CompressionCodec import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions, AvroSerializer, SchemaConverters} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.streaming.CheckpointFileManager +import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, StreamExecution} import org.apache.spark.sql.types.StructType -import org.apache.spark.util.Utils +import org.apache.spark.util.{NonFateSharingCache, Utils} private[sql] class RocksDBStateStoreProvider extends StateStoreProvider with Logging with Closeable @@ -76,16 +78,17 @@ private[sql] class RocksDBStateStoreProvider verifyColFamilyCreationOrDeletion("create_col_family", colFamilyName, isInternal) val newColFamilyId = rocksDB.createColFamilyIfAbsent(colFamilyName) // Create cache key using store ID to avoid collisions - val avroEncCacheKey = s"${stateStoreId.operatorId}_" + + val avroEncCacheKey = s"${getRunId}_${stateStoreId.operatorId}_" + s"${stateStoreId.partitionId}_$colFamilyName" - // If we have not created the avroEncoder for this column family, create - // it, or look in the cache maintained in the RocksDBStateStoreProvider - // companion object - lazy val avroEnc = stateStoreEncoding match { + def avroEnc = stateStoreEncoding match { case "avro" => Some( - RocksDBStateStoreProvider.avroEncoderMap.computeIfAbsent(avroEncCacheKey, - _ => getAvroEnc(keyStateEncoderSpec, valueSchema)) + RocksDBStateStoreProvider.avroEncoderMap.get( + avroEncCacheKey, + new java.util.concurrent.Callable[AvroEncoder] { + override def call(): AvroEncoder = getAvroEnc(keyStateEncoderSpec, valueSchema) + } + ) ) case "unsaferow" => None } @@ -95,6 +98,17 
@@ private[sql] class RocksDBStateStoreProvider Some(newColFamilyId), avroEnc), RocksDBStateEncoder.getValueEncoder(valueSchema, useMultipleValuesPerKey, avroEnc))) } + + private def getRunId: String = { + val runId = hadoopConf.get(StreamExecution.RUN_ID_KEY) + if (runId != null) { + runId + } else { + assert(Utils.isTesting, "Failed to find query id/batch Id in task context") + UUID.randomUUID().toString + } + } + private def getAvroSerializer(schema: StructType): AvroSerializer = { val avroType = SchemaConverters.toAvroType(schema) new AvroSerializer(schema, avroType, nullable = false) @@ -657,8 +671,16 @@ object RocksDBStateStoreProvider { val STATE_ENCODING_VERSION: Byte = 0 val VIRTUAL_COL_FAMILY_PREFIX_BYTES = 2 + private val MAX_AVRO_ENCODERS_IN_CACHE = 1000 // Add the cache at companion object level so it persists across provider instances - private val avroEncoderMap = new java.util.concurrent.ConcurrentHashMap[String, AvroEncoder]() + private val avroEncoderMap: NonFateSharingCache[String, AvroEncoder] = { + val guavaCache = CacheBuilder.newBuilder() + .maximumSize(MAX_AVRO_ENCODERS_IN_CACHE) // Adjust size based on your needs + .expireAfterAccess(1, TimeUnit.HOURS) // Optional: Add expiration if needed + .build[String, AvroEncoder]() + + new NonFateSharingCache(guavaCache) + } // Native operation latencies report as latency in microseconds // as SQLMetrics support millis. Convert the value to millis diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala index 69a29cdbe7a17..6299c25353a97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala @@ -38,8 +38,20 @@ case class StateSchemaValidationResult( schemaPath: String ) -// Avro encoder that is used by the RocksDBStateStoreProvider and RocksDBStateEncoder -// in order to serialize from UnsafeRow to a byte array of Avro encoding. +/** An Avro-based encoder used for serializing between UnsafeRow and Avro + * byte arrays in RocksDB state stores. + * + * This encoder is primarily utilized by [[RocksDBStateStoreProvider]] and [[RocksDBStateEncoder]] + * to handle serialization and deserialization of state store data. 
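
The companion-object cache above bounds how many AvroEncoder instances stay alive and keys each entry by run id, operator, partition and column family, so encoders are reused within a query run and dropped once idle. A standalone sketch of the same lookup pattern against a plain Guava cache (the NonFateSharingCache used above wraps a cache built this way); the Encoder stand-in and the sizing values are illustrative only:

    import java.util.concurrent.{Callable, TimeUnit}
    import com.google.common.cache.{Cache, CacheBuilder}

    object EncoderCacheSketch {
      // Stand-in for the real AvroEncoder; only the caching pattern is illustrated.
      final case class Encoder(colFamily: String)

      private val cache: Cache[String, Encoder] = CacheBuilder.newBuilder()
        .maximumSize(1000)                    // bound the number of cached encoders
        .expireAfterAccess(1, TimeUnit.HOURS) // drop entries for long-idle queries
        .build[String, Encoder]()

      def encoderFor(
          runId: String,
          operatorId: Long,
          partitionId: Int,
          colFamily: String): Encoder = {
        val key = s"${runId}_${operatorId}_${partitionId}_$colFamily"
        // Concurrent lookups for the same key wait for a single in-flight load
        // instead of each constructing their own encoder.
        cache.get(key, new Callable[Encoder] {
          override def call(): Encoder = Encoder(colFamily) // expensive construction in real code
        })
      }
    }
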
+ * + * @param keySerializer Serializer for converting state store keys to Avro format + * @param keyDeserializer Deserializer for converting Avro-encoded keys back to UnsafeRow + * @param valueSerializer Serializer for converting state store values to Avro format + * @param valueDeserializer Deserializer for converting Avro-encoded values back to UnsafeRow + * @param suffixKeySerializer Optional serializer for handling suffix keys in Avro format + * @param suffixKeyDeserializer Optional deserializer for converting Avro-encoded suffix + * keys back to UnsafeRow + */ case class AvroEncoder( keySerializer: AvroSerializer, keyDeserializer: AvroDeserializer, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index c3449b6a12626..e2b93c147891d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.util.UnsafeRowUtils import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo +import org.apache.spark.sql.execution.streaming.{StatefulOperatorStateInfo, StreamExecution} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{NextIterator, ThreadUtils, Utils} @@ -781,6 +781,7 @@ object StateStore extends Logging { if (version < 0) { throw QueryExecutionErrors.unexpectedStateStoreVersion(version) } + hadoopConf.set(StreamExecution.RUN_ID_KEY, storeProviderId.queryRunId.toString) val storeProvider = getStateStoreProvider(storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, useColumnFamilies, storeConf, hadoopConf, useMultipleValuesPerKey) storeProvider.getStore(version, stateStoreCkptId) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala index b20e6d7d49a9a..22876831c00d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala @@ -60,7 +60,7 @@ class ListStateSuite extends StateVariableSuiteBase { } Seq("appendList", "put").foreach { listImplFunc => - testWithEncodingTypes(s"Test list operation($listImplFunc) with null") { + test(s"Test list operation($listImplFunc) with null") { testMapStateWithNullUserKey() { listState => listImplFunc match { case "appendList" => listState.appendList(null) @@ -70,7 +70,7 @@ class ListStateSuite extends StateVariableSuiteBase { } } - testWithEncodingTypes("List state operations for single instance") { + test("List state operations for single instance") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), @@ -99,7 +99,7 @@ class ListStateSuite extends StateVariableSuiteBase { } } - testWithEncodingTypes("List state operations for multiple instance") { + test("List state operations for multiple instance") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = 
provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), @@ -139,7 +139,7 @@ class ListStateSuite extends StateVariableSuiteBase { } } - testWithEncodingTypes("List state operations with list, value, another list instances") { + test("List state operations with list, value, another list instances") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), @@ -170,7 +170,7 @@ class ListStateSuite extends StateVariableSuiteBase { } } - testWithEncodingTypes(s"test List state TTL") { + test(s"test List state TTL") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val timestampMs = 10 @@ -226,7 +226,7 @@ class ListStateSuite extends StateVariableSuiteBase { } } - testWithEncodingTypes("test null or negative TTL duration throws error") { + test("test null or negative TTL duration throws error") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val batchTimestampMs = 10 @@ -253,7 +253,7 @@ class ListStateSuite extends StateVariableSuiteBase { } } - testWithEncodingTypes("ListState TTL with non-primitive types") { + test("ListState TTL with non-primitive types") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val timestampMs = 10 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala index b913133deddeb..9a0a891d538ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala @@ -38,7 +38,7 @@ class MapStateSuite extends StateVariableSuiteBase { import testImplicits._ - testWithEncodingTypes("Map state operations for single instance") { + test("Map state operations for single instance") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), @@ -73,7 +73,7 @@ class MapStateSuite extends StateVariableSuiteBase { } } - testWithEncodingTypes("Map state operations for multiple map instances") { + test("Map state operations for multiple map instances") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), @@ -113,7 +113,7 @@ class MapStateSuite extends StateVariableSuiteBase { } } - testWithEncodingTypes("Map state operations with list, value, another map instances") { + test("Map state operations with list, value, another map instances") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), @@ -173,7 +173,7 @@ class MapStateSuite extends StateVariableSuiteBase { } } - testWithEncodingTypes("test Map state TTL") { + test("test Map state TTL") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val timestampMs = 10 @@ -231,7 +231,7 @@ class MapStateSuite extends StateVariableSuiteBase { } } - testWithEncodingTypes("test null or 
negative TTL duration throws error") { + test("test null or negative TTL duration throws error") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val batchTimestampMs = 10 @@ -259,7 +259,7 @@ class MapStateSuite extends StateVariableSuiteBase { } } - testWithEncodingTypes("Map state with TTL with non-primitive types") { + test("Map state with TTL with non-primitive types") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val timestampMs = 10 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index c895aced455e8..0abdcadefbd55 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -43,6 +43,7 @@ import org.apache.spark.util.Utils @ExtendedSQLTest class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvider] with AlsoTestWithChangelogCheckpointingEnabled + with AlsoTestWithEncodingTypes with SharedSparkSession with BeforeAndAfter { @@ -58,7 +59,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid import StateStoreTestsHelper._ - testWithColumnFamiliesAndEncodingTypes(s"version encoding", + testWithColumnFamilies(s"version encoding", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => import RocksDBStateStoreProvider._ @@ -127,7 +128,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamiliesAndEncodingTypes("rocksdb file manager metrics exposed", + testWithColumnFamilies("rocksdb file manager metrics exposed", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => import RocksDBStateStoreProvider._ def getCustomMetric(metrics: StateStoreMetrics, @@ -162,7 +163,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamiliesAndEncodingTypes("rocksdb range scan validation - invalid num columns", + testWithColumnFamilies("rocksdb range scan validation - invalid num columns", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => // zero ordering cols val ex1 = intercept[SparkUnsupportedOperationException] { @@ -201,7 +202,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid ) } - testWithColumnFamiliesAndEncodingTypes("rocksdb range scan validation - variable sized columns", + testWithColumnFamilies("rocksdb range scan validation - variable sized columns", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => val keySchemaWithVariableSizeCols: StructType = StructType( Seq(StructField("key1", StringType, false), StructField("key2", StringType, false))) @@ -224,8 +225,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid ) } - testWithColumnFamiliesAndEncodingTypes("rocksdb range scan validation - " + - "variable size data types unsupported", + testWithColumnFamilies("rocksdb range scan validation - variable size data types unsupported", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => val keySchemaWithSomeUnsupportedTypeCols: StructType = StructType(Seq( StructField("key1", StringType, false), @@ 
-265,7 +265,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamiliesAndEncodingTypes("rocksdb range scan validation - null type columns", + testWithColumnFamilies("rocksdb range scan validation - null type columns", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => val keySchemaWithNullTypeCols: StructType = StructType( Seq(StructField("key1", NullType, false), StructField("key2", StringType, false))) @@ -288,7 +288,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid ) } - testWithColumnFamiliesAndEncodingTypes("rocksdb range scan - fixed size non-ordering columns", + testWithColumnFamilies("rocksdb range scan - fixed size non-ordering columns", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, @@ -340,8 +340,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamiliesAndEncodingTypes("rocksdb range scan - variable " + - "size non-ordering columns with " + + testWithColumnFamilies("rocksdb range scan - variable size non-ordering columns with " + "double type values are supported", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => @@ -397,7 +396,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamiliesAndEncodingTypes("rocksdb range scan - variable size non-ordering columns", + testWithColumnFamilies("rocksdb range scan - variable size non-ordering columns", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, @@ -450,8 +449,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamiliesAndEncodingTypes("rocksdb range scan multiple " + - "ordering columns - variable size " + + testWithColumnFamilies("rocksdb range scan multiple ordering columns - variable size " + s"non-ordering columns", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => @@ -495,16 +493,15 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamiliesAndEncodingTypes("rocksdb range scan multiple " + - "non-contiguous ordering columns", + testWithColumnFamilies("rocksdb range scan multiple non-contiguous ordering columns", TestWithBothChangelogCheckpointingEnabledAndDisabled ) { colFamiliesEnabled => val testSchema: StructType = StructType( Seq( - StructField("ordering1", LongType, false), + StructField("ordering-1", LongType, false), StructField("key2", StringType, false), - StructField("ordering2", IntegerType, false), - StructField("string2", StringType, false), - StructField("ordering3", DoubleType, false) + StructField("ordering-2", IntegerType, false), + StructField("string-2", StringType, false), + StructField("ordering-3", DoubleType, false) ) ) @@ -586,8 +583,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } - testWithColumnFamiliesAndEncodingTypes("rocksdb range scan multiple " + - "ordering columns - variable size " + + testWithColumnFamilies("rocksdb range scan multiple ordering columns - variable size " + s"non-ordering columns with null values in first ordering column", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => @@ -687,8 +683,7 @@ class RocksDBStateStoreSuite extends 
StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamiliesAndEncodingTypes("rocksdb range scan multiple " + - "ordering columns - variable size " + + testWithColumnFamilies("rocksdb range scan multiple ordering columns - variable size " + s"non-ordering columns with null values in second ordering column", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => @@ -741,8 +736,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamiliesAndEncodingTypes("rocksdb range scan byte " + - "ordering column - variable size " + + testWithColumnFamilies("rocksdb range scan byte ordering column - variable size " + s"non-ordering columns", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => @@ -786,8 +780,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamiliesAndEncodingTypes("rocksdb range scan - ordering cols " + - "and key schema cols are same", + testWithColumnFamilies("rocksdb range scan - ordering cols and key schema cols are same", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => // use the same schema as value schema for single col key schema @@ -829,7 +822,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamiliesAndEncodingTypes("rocksdb range scan - with prefix scan", + testWithColumnFamilies("rocksdb range scan - with prefix scan", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, @@ -866,8 +859,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamiliesAndEncodingTypes("rocksdb key and value schema encoders " + - "for column families", + testWithColumnFamilies("rocksdb key and value schema encoders for column families", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => val testColFamily = "testState" @@ -928,7 +920,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } /* Column family related tests */ - testWithColumnFamiliesAndEncodingTypes("column family creation with invalid names", + testWithColumnFamilies("column family creation with invalid names", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => tryWithProviderResource( newStoreProvider(useColumnFamilies = colFamiliesEnabled)) { provider => @@ -965,7 +957,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamiliesAndEncodingTypes(s"column family creation with reserved chars", + testWithColumnFamilies(s"column family creation with reserved chars", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => tryWithProviderResource( newStoreProvider(useColumnFamilies = colFamiliesEnabled)) { provider => @@ -1001,7 +993,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamiliesAndEncodingTypes(s"operations on absent column family", + testWithColumnFamilies(s"operations on absent column family", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => tryWithProviderResource( newStoreProvider(useColumnFamilies = colFamiliesEnabled)) { provider => @@ -1154,7 +1146,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid Seq( NoPrefixKeyStateEncoderSpec(keySchema), 
PrefixKeyScanStateEncoderSpec(keySchema, 1) ).foreach { keyEncoder => - testWithColumnFamiliesAndEncodingTypes(s"validate rocksdb " + + testWithColumnFamilies(s"validate rocksdb " + s"${keyEncoder.getClass.toString.split('.').last} correctness", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => tryWithProviderResource(newStoreProvider(keySchema, keyEncoder, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala index 30e96e15684b4..3a25fa73afc31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala @@ -86,6 +86,19 @@ trait RocksDBStateStoreChangelogCheckpointingTestUtil { } } +trait AlsoTestWithEncodingTypes extends SQLTestUtils { + override protected def test(testName: String, testTags: Tag*)(testBody: => Any) + (implicit pos: Position): Unit = { + Seq("unsaferow", "avro").foreach { encoding => + super.test(s"$testName (encoding = $encoding)", testTags: _*) { + withSQLConf(SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key -> encoding) { + testBody + } + } + } + } +} + trait AlsoTestWithChangelogCheckpointingEnabled extends SQLTestUtils with RocksDBStateStoreChangelogCheckpointingTestUtil { @@ -147,56 +160,6 @@ trait AlsoTestWithChangelogCheckpointingEnabled } } - def testWithColumnFamiliesAndEncodingTypes( - testName: String, - testMode: TestMode, - testTags: Tag*) - (testBody: Boolean => Any): Unit = { - - Seq(true, false).foreach { colFamiliesEnabled => - Seq("unsaferow", "avro").foreach { encoding => - testMode match { - case TestWithChangelogCheckpointingEnabled => - testWithChangelogCheckpointingEnabled( - s"$testName - with colFamiliesEnabled=$colFamiliesEnabled (encoding = $encoding)", - testTags: _*) { - withSQLConf(SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key -> encoding) { - testBody(colFamiliesEnabled) - } - } - - case TestWithChangelogCheckpointingDisabled => - testWithChangelogCheckpointingDisabled( - s"$testName - with colFamiliesEnabled=$colFamiliesEnabled (encoding = $encoding)", - testTags: _*) { - withSQLConf(SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key -> encoding) { - testBody(colFamiliesEnabled) - } - } - - case TestWithBothChangelogCheckpointingEnabledAndDisabled => - testWithChangelogCheckpointingEnabled( - s"$testName - with colFamiliesEnabled=$colFamiliesEnabled (encoding = $encoding)", - testTags: _*) { - withSQLConf(SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key -> encoding) { - testBody(colFamiliesEnabled) - } - } - testWithChangelogCheckpointingDisabled( - s"$testName - with colFamiliesEnabled=$colFamiliesEnabled (encoding = $encoding)", - testTags: _*) { - withSQLConf(SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key -> encoding) { - testBody(colFamiliesEnabled) - } - } - - case _ => - throw new IllegalArgumentException(s"Unknown test mode: $testMode") - } - } - } - } - def testWithColumnFamilies( testName: String, testMode: TestMode, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala index be52691315259..24a120be9d9af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala @@ -30,7 +30,7 @@ class TimerSuite extends StateVariableSuiteBase { private def testWithTimeMode(testName: String) (testFunc: TimeMode => Unit): Unit = { Seq("Processing", "Event").foreach { timeoutMode => - testWithEncodingTypes(s"$timeoutMode timer - " + testName) { + test(s"$timeoutMode timer - " + testName) { timeoutMode match { case "Processing" => testFunc(TimeMode.ProcessingTime()) case "Event" => testFunc(TimeMode.EventTime()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala index 3d7b56a286b95..8984d9b0845b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala @@ -46,7 +46,7 @@ class ValueStateSuite extends StateVariableSuiteBase { import StateStoreTestsHelper._ import testImplicits._ - testWithEncodingTypes("Implicit key operations") { + test("Implicit key operations") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), @@ -91,7 +91,7 @@ class ValueStateSuite extends StateVariableSuiteBase { } } - testWithEncodingTypes("Value state operations for single instance") { + test("Value state operations for single instance") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), @@ -118,7 +118,7 @@ class ValueStateSuite extends StateVariableSuiteBase { } } - testWithEncodingTypes("Value state operations for multiple instances") { + test("Value state operations for multiple instances") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), @@ -163,7 +163,7 @@ class ValueStateSuite extends StateVariableSuiteBase { } } - testWithEncodingTypes("Value state operations for unsupported type name should fail") { + test("Value state operations for unsupported type name should fail") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, @@ -184,7 +184,7 @@ class ValueStateSuite extends StateVariableSuiteBase { } } - testWithEncodingTypes("colFamily with HDFSBackedStateStoreProvider should fail") { + test("colFamily with HDFSBackedStateStoreProvider should fail") { val storeId = StateStoreId(newDir(), Random.nextInt(), 0) val provider = new HDFSBackedStateStoreProvider() val storeConf = new StateStoreConf(new SQLConf()) @@ -203,8 +203,7 @@ class ValueStateSuite extends StateVariableSuiteBase { ) } - testWithEncodingTypes("test SQL encoder - Value state operations for " + - "Primitive(Double) instances") { + test("test SQL encoder - Value state operations for Primitive(Double) instances") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), @@ -230,8 +229,7 @@ class ValueStateSuite extends StateVariableSuiteBase { } } - testWithEncodingTypes("test SQL encoder - Value state operations " + 
- "for Primitive(Long) instances") { + test("test SQL encoder - Value state operations for Primitive(Long) instances") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), @@ -257,7 +255,7 @@ class ValueStateSuite extends StateVariableSuiteBase { } } - testWithEncodingTypes("test SQL encoder - Value state operations for case class instances") { + test("test SQL encoder - Value state operations for case class instances") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), @@ -283,7 +281,7 @@ class ValueStateSuite extends StateVariableSuiteBase { } } - testWithEncodingTypes("test SQL encoder - Value state operations for POJO instances") { + test("test SQL encoder - Value state operations for POJO instances") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), @@ -309,7 +307,7 @@ class ValueStateSuite extends StateVariableSuiteBase { } } - testWithEncodingTypes(s"test Value state TTL") { + test(s"test Value state TTL") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val timestampMs = 10 @@ -365,7 +363,7 @@ class ValueStateSuite extends StateVariableSuiteBase { } } - testWithEncodingTypes("test null or zero TTL duration throws error") { + test("test null or zero TTL duration throws error") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val batchTimestampMs = 10 @@ -392,7 +390,7 @@ class ValueStateSuite extends StateVariableSuiteBase { } } - testWithEncodingTypes("Value state TTL with non-primitive type") { + test("Value state TTL with non-primitive type") { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val timestampMs = 10 @@ -425,7 +423,7 @@ class ValueStateSuite extends StateVariableSuiteBase { * types (ValueState, ListState, MapState) used in arbitrary stateful operators. 
*/ abstract class StateVariableSuiteBase extends SharedSparkSession - with BeforeAndAfter with AlsoTestWithChangelogCheckpointingEnabled { + with BeforeAndAfter with AlsoTestWithEncodingTypes { before { StateStore.stop() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index d90e9a86e30ad..b5e2104d86251 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -780,7 +780,8 @@ class TransformWithStateSuite extends StateStoreMetricsTest ) } - testWithEncodingTypes("Use statefulProcessor without transformWithState - handle should be absent") { + testWithEncodingTypes("Use statefulProcessor without transformWithState -" + + " handle should be absent") { val processor = new RunningCountStatefulProcessor() val ex = intercept[Exception] { processor.getHandle From feaca2073bd779abfae0609bbe8b1a841e1a6c57 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Fri, 22 Nov 2024 01:29:13 -0800 Subject: [PATCH 6/9] alsotestwith --- .../streaming/state/RocksDBSuite.scala | 19 ------- .../TransformWithListStateSuite.scala | 21 ++++---- .../TransformWithListStateTTLSuite.scala | 6 +-- .../TransformWithMapStateSuite.scala | 15 +++--- .../TransformWithMapStateTTLSuite.scala | 4 +- .../streaming/TransformWithStateSuite.scala | 54 +++++++++---------- .../streaming/TransformWithStateTTLTest.scala | 5 +- .../TransformWithValueStateTTLSuite.scala | 4 +- 8 files changed, 56 insertions(+), 72 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala index 3a25fa73afc31..61ca8e7c32f61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala @@ -141,25 +141,6 @@ trait AlsoTestWithChangelogCheckpointingEnabled } } - def testWithEncodingTypes( - testName: String, - testTags: Tag*) - (testBody: => Any): Unit = { - Seq("unsaferow", "avro").foreach { encoding => - super.test(testName + s" (encoding = $encoding)", testTags: _*) { - // in case tests have any code that needs to execute before every test - // super.beforeEach() - withSQLConf( - SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key -> - encoding) { - testBody - } - // in case tests have any code that needs to execute after every test - // super.afterEach() - } - } - } - def testWithColumnFamilies( testName: String, testMode: TestMode, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala index 9f26de126e0dd..5d88db0d01ba4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming import org.apache.spark.SparkIllegalArgumentException import org.apache.spark.sql.Encoders import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider} +import 
org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, AlsoTestWithEncodingTypes, RocksDBStateStoreProvider} import org.apache.spark.sql.internal.SQLConf case class InputRow(key: String, action: String, value: String) @@ -127,10 +127,11 @@ class ToggleSaveAndEmitProcessor } class TransformWithListStateSuite extends StreamTest - with AlsoTestWithChangelogCheckpointingEnabled { + with AlsoTestWithChangelogCheckpointingEnabled + with AlsoTestWithEncodingTypes { import testImplicits._ - testWithEncodingTypes("test appending null value in list state throw exception") { + test("test appending null value in list state throw exception") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -150,7 +151,7 @@ class TransformWithListStateSuite extends StreamTest } } - testWithEncodingTypes("test putting null value in list state throw exception") { + test("test putting null value in list state throw exception") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -170,7 +171,7 @@ class TransformWithListStateSuite extends StreamTest } } - testWithEncodingTypes("test putting null list in list state throw exception") { + test("test putting null list in list state throw exception") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -190,7 +191,7 @@ class TransformWithListStateSuite extends StreamTest } } - testWithEncodingTypes("test appending null list in list state throw exception") { + test("test appending null list in list state throw exception") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -210,7 +211,7 @@ class TransformWithListStateSuite extends StreamTest } } - testWithEncodingTypes("test putting empty list in list state throw exception") { + test("test putting empty list in list state throw exception") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -230,7 +231,7 @@ class TransformWithListStateSuite extends StreamTest } } - testWithEncodingTypes("test appending empty list in list state throw exception") { + test("test appending empty list in list state throw exception") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -250,7 +251,7 @@ class TransformWithListStateSuite extends StreamTest } } - testWithEncodingTypes("test list state correctness") { + test("test list state correctness") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -307,7 +308,7 @@ class TransformWithListStateSuite extends StreamTest } } - testWithEncodingTypes("test ValueState And ListState in Processor") { + test("test ValueState And ListState in Processor") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala index ebd29bff5d354..409a255ae3e64 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala @@ -105,7 +105,7 @@ class TransformWithListStateTTLSuite extends TransformWithStateTTLTest { override def getStateTTLMetricName: 
String = "numListStateWithTTLVars" - testWithEncodingTypes("verify iterator works with expired values in beginning of list") { + test("verify iterator works with expired values in beginning of list") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> "1") { @@ -195,7 +195,7 @@ class TransformWithListStateTTLSuite extends TransformWithStateTTLTest { // ascending order of TTL by stopping the query, setting the new TTL, and restarting // the query to check that the expired elements in the middle or end of the list // are not returned. - testWithEncodingTypes("verify iterator works with expired values in middle of list") { + test("verify iterator works with expired values in middle of list") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> "1") { @@ -343,7 +343,7 @@ class TransformWithListStateTTLSuite extends TransformWithStateTTLTest { } } - testWithEncodingTypes("verify iterator works with expired values in end of list") { + test("verify iterator works with expired values in end of list") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> "1") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala index 410ad55cad480..ec6ff4fcceb67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming import org.apache.spark.SparkIllegalArgumentException import org.apache.spark.sql.Encoders import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider} +import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, AlsoTestWithEncodingTypes, RocksDBStateStoreProvider} import org.apache.spark.sql.internal.SQLConf case class InputMapRow(key: String, action: String, value: (String, String)) @@ -81,7 +81,8 @@ class TestMapStateProcessor * operators such as transformWithState. 
*/ class TransformWithMapStateSuite extends StreamTest - with AlsoTestWithChangelogCheckpointingEnabled { + with AlsoTestWithChangelogCheckpointingEnabled + with AlsoTestWithEncodingTypes { import testImplicits._ private def testMapStateWithNullUserKey(inputMapRow: InputMapRow): Unit = { @@ -110,7 +111,7 @@ class TransformWithMapStateSuite extends StreamTest } } - testWithEncodingTypes("Test retrieving value with non-existing user key") { + test("Test retrieving value with non-existing user key") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -129,12 +130,12 @@ class TransformWithMapStateSuite extends StreamTest } Seq("getValue", "containsKey", "updateValue", "removeKey").foreach { mapImplFunc => - testWithEncodingTypes(s"Test $mapImplFunc with null user key") { + test(s"Test $mapImplFunc with null user key") { testMapStateWithNullUserKey(InputMapRow("k1", mapImplFunc, (null, ""))) } } - testWithEncodingTypes("Test put value with null value") { + test("Test put value with null value") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -158,7 +159,7 @@ class TransformWithMapStateSuite extends StreamTest } } - testWithEncodingTypes("Test map state correctness") { + test("Test map state correctness") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { val inputData = MemoryStream[InputMapRow] @@ -219,7 +220,7 @@ class TransformWithMapStateSuite extends StreamTest } } - testWithEncodingTypes("transformWithMapState - batch should succeed") { + test("transformWithMapState - batch should succeed") { val inputData = Seq( InputMapRow("k1", "updateValue", ("v1", "10")), InputMapRow("k1", "getValue", ("v1", ""))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala index a68632534c001..022280eb3bcef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala @@ -182,7 +182,7 @@ class TransformWithMapStateTTLSuite extends TransformWithStateTTLTest { override def getStateTTLMetricName: String = "numMapStateWithTTLVars" - testWithEncodingTypes("validate state is evicted with multiple user keys") { + test("validate state is evicted with multiple user keys") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> "1") { @@ -224,7 +224,7 @@ class TransformWithMapStateTTLSuite extends TransformWithStateTTLTest { } } - testWithEncodingTypes("verify iterator doesn't return expired keys") { + test("verify iterator doesn't return expired keys") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> "1") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index b5e2104d86251..91a47645f4179 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -429,11 +429,11 @@ class SleepingTimerProcessor extends StatefulProcessor[String, String, String] { * Class 
that adds tests for transformWithState stateful streaming operator */ class TransformWithStateSuite extends StateStoreMetricsTest - with AlsoTestWithChangelogCheckpointingEnabled { + with AlsoTestWithChangelogCheckpointingEnabled with AlsoTestWithEncodingTypes { import testImplicits._ - testWithEncodingTypes("transformWithState - streaming with rocksdb and" + + test("transformWithState - streaming with rocksdb and" + " invalid processor should fail") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, @@ -455,7 +455,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - testWithEncodingTypes("transformWithState - lazy iterators can properly get/set keyed state") { + test("transformWithState - lazy iterators can properly get/set keyed state") { val spark = this.spark import spark.implicits._ @@ -533,7 +533,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - testWithEncodingTypes("transformWithState - streaming with rocksdb should succeed") { + test("transformWithState - streaming with rocksdb should succeed") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -571,7 +571,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - testWithEncodingTypes("transformWithState - streaming with rocksdb and processing time timer " + + test("transformWithState - streaming with rocksdb and processing time timer " + "should succeed") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -616,7 +616,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - testWithEncodingTypes("transformWithState - streaming with rocksdb and processing time timer " + + test("transformWithState - streaming with rocksdb and processing time timer " + "and updating timers should succeed") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -652,7 +652,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - testWithEncodingTypes("transformWithState - streaming with rocksdb and processing time timer " + + test("transformWithState - streaming with rocksdb and processing time timer " + "and multiple timers should succeed") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { @@ -689,7 +689,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - testWithEncodingTypes("transformWithState - streaming with rocksdb and event " + + test("transformWithState - streaming with rocksdb and event " + "time based timer") { val inputData = MemoryStream[(String, Int)] val result = @@ -780,7 +780,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest ) } - testWithEncodingTypes("Use statefulProcessor without transformWithState -" + + test("Use statefulProcessor without transformWithState -" + " handle should be absent") { val processor = new RunningCountStatefulProcessor() val ex = intercept[Exception] { @@ -793,7 +793,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest ) } - testWithEncodingTypes("transformWithState - batch should succeed") { + test("transformWithState - batch should succeed") { val inputData = Seq("a", "b") val result = inputData.toDS() .groupByKey(x => x) @@ -805,7 +805,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest checkAnswer(df, Seq(("a", "1"), ("b", "1")).toDF()) } - testWithEncodingTypes("transformWithState 
- test deleteIfExists operator") { + test("transformWithState - test deleteIfExists operator") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -846,7 +846,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - testWithEncodingTypes("transformWithState - two input streams") { + test("transformWithState - two input streams") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -876,7 +876,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - testWithEncodingTypes("transformWithState - three input streams") { + test("transformWithState - three input streams") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -911,7 +911,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - testWithEncodingTypes("transformWithState - two input streams, different key type") { + test("transformWithState - two input streams, different key type") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -958,7 +958,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest OutputMode.Update()) } - testWithEncodingTypes("transformWithState - availableNow trigger mode, rate limit is respected") { + test("transformWithState - availableNow trigger mode, rate limit is respected") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { withTempDir { srcDir => @@ -999,7 +999,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - testWithEncodingTypes("transformWithState - availableNow trigger mode, multiple restarts") { + test("transformWithState - availableNow trigger mode, multiple restarts") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { withTempDir { srcDir => @@ -1037,7 +1037,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - testWithEncodingTypes("transformWithState - verify StateSchemaV3 writes " + + test("transformWithState - verify StateSchemaV3 writes " + "correct SQL schema of key/value") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, @@ -1120,7 +1120,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - testWithEncodingTypes("transformWithState - verify that OperatorStateMetadataV2" + + test("transformWithState - verify that OperatorStateMetadataV2" + " file is being written correctly") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, @@ -1164,7 +1164,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - testWithEncodingTypes("test that invalid schema evolution fails query for column family") { + test("test that invalid schema evolution fails query for column family") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -1201,7 +1201,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - testWithEncodingTypes("test that different outputMode after query restart fails") { + test("test that different outputMode after query restart fails") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, 
SQLConf.SHUFFLE_PARTITIONS.key -> @@ -1244,7 +1244,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - testWithEncodingTypes("test that changing between different state variable types fails") { + test("test that changing between different state variable types fails") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -1286,7 +1286,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - testWithEncodingTypes("test that different timeMode after query restart fails") { + test("test that different timeMode after query restart fails") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -1333,7 +1333,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - testWithEncodingTypes("test that introducing TTL after restart fails query") { + test("test that introducing TTL after restart fails query") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -1387,7 +1387,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - testWithEncodingTypes("test query restart with new state variable succeeds") { + test("test query restart with new state variable succeeds") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -1433,7 +1433,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - testWithEncodingTypes("test query restart succeeds") { + test("test query restart succeeds") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -1468,7 +1468,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - testWithEncodingTypes("SPARK-49070: transformWithState - valid initial state plan") { + test("SPARK-49070: transformWithState - valid initial state plan") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { withTempDir { srcDir => @@ -1518,7 +1518,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest new Path(stateCheckpointPath, "_stateSchema/default/") } - testWithEncodingTypes("transformWithState - verify that metadata and schema logs are purged") { + test("transformWithState - verify that metadata and schema logs are purged") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala index 92a06f1183935..75fda9630779e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala @@ -21,7 +21,7 @@ import java.sql.Timestamp import java.time.Duration import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider} +import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, AlsoTestWithEncodingTypes, RocksDBStateStoreProvider} import org.apache.spark.sql.internal.SQLConf import 
org.apache.spark.sql.streaming.util.StreamManualClock @@ -41,7 +41,8 @@ case class OutputEvent( * Test suite base for TransformWithState with TTL support. */ abstract class TransformWithStateTTLTest - extends StreamTest with AlsoTestWithChangelogCheckpointingEnabled { + extends StreamTest with AlsoTestWithChangelogCheckpointingEnabled + with AlsoTestWithEncodingTypes { import testImplicits._ def getProcessor(ttlConfig: TTLConfig): StatefulProcessor[String, InputEvent, OutputEvent] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala index 9727c2dc8c113..b19c126c7386b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala @@ -195,7 +195,7 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest { override def getStateTTLMetricName: String = "numValueStateWithTTLVars" - testWithEncodingTypes("validate multiple value states") { + test("validate multiple value states") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { val ttlKey = "k1" @@ -262,7 +262,7 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest { } } - testWithEncodingTypes("verify StateSchemaV3 writes correct SQL " + + test("verify StateSchemaV3 writes correct SQL " + "schema of key/value and with TTL") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, From 02eaeb783771b2892fda7687f1f417b1beca3be1 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Fri, 22 Nov 2024 10:39:24 -0800 Subject: [PATCH 7/9] scaladoc --- .../streaming/state/StateSchemaCompatibilityChecker.scala | 3 ++- .../v2/state/StateDataSourceTransformWithStateSuite.scala | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala index 6299c25353a97..48b15ac04f40b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala @@ -38,7 +38,8 @@ case class StateSchemaValidationResult( schemaPath: String ) -/** An Avro-based encoder used for serializing between UnsafeRow and Avro +/** + * An Avro-based encoder used for serializing between UnsafeRow and Avro * byte arrays in RocksDB state stores. 
* * This encoder is primarily utilized by [[RocksDBStateStoreProvider]] and [[RocksDBStateEncoder]] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala index baab6327b35c1..af64f563cf7b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.io.CompressionCodec import org.apache.spark.sql.{Encoders, Row} import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBFileManager, RocksDBStateStoreProvider, TestClass} +import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, AlsoTestWithEncodingTypes, RocksDBFileManager, RocksDBStateStoreProvider, TestClass} import org.apache.spark.sql.functions.{col, explode, timestamp_seconds} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.{InputMapRow, ListState, MapInputEvent, MapOutputEvent, MapStateTTLProcessor, MaxEventTimeStatefulProcessor, OutputMode, RunningCountStatefulProcessor, RunningCountStatefulProcessorWithProcTimeTimerUpdates, StatefulProcessor, StateStoreMetricsTest, TestMapStateProcessor, TimeMode, TimerValues, TransformWithStateSuiteUtils, Trigger, TTLConfig, ValueState} @@ -126,7 +126,7 @@ class SessionGroupsStatefulProcessorWithTTL extends * Test suite to verify integration of state data source reader with the transformWithState operator */ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest - with AlsoTestWithChangelogCheckpointingEnabled { + with AlsoTestWithChangelogCheckpointingEnabled with AlsoTestWithEncodingTypes { import testImplicits._ From 3b753d103b4d1b8aaef8bd99264b4a057b6a2fdb Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Fri, 22 Nov 2024 14:47:19 -0800 Subject: [PATCH 8/9] works with state data source --- .../sql/execution/streaming/state/RocksDBStateEncoder.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index d33083210c0e3..f39022c1f53a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -1105,7 +1105,7 @@ class NoPrefixKeyStateEncoder( val avroBytes = new Array[Byte](dataLength) Platform.copyMemory( keyBytes, decodeKeyStartOffset + STATE_ENCODING_NUM_VERSION_BYTES, - avroBytes, 0, dataLength) + avroBytes, Platform.BYTE_ARRAY_OFFSET, dataLength) decodeFromAvroToUnsafeRow(avroBytes, avroEnc.get.keyDeserializer, keyAvroType, keyProj) } else { keyRow.pointTo( From f22bcbf4a96c5f4992c23a4aa309e18b7f42db55 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Fri, 22 Nov 2024 15:05:37 -0800 Subject: [PATCH 9/9] tests pass --- .../state/RocksDBStateStoreProvider.scala | 159 ++++++++++-------- 1 file changed, 88 insertions(+), 71 deletions(-) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 479217f749663..e5a4175aeec1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -78,20 +78,11 @@ private[sql] class RocksDBStateStoreProvider verifyColFamilyCreationOrDeletion("create_col_family", colFamilyName, isInternal) val newColFamilyId = rocksDB.createColFamilyIfAbsent(colFamilyName) // Create cache key using store ID to avoid collisions - val avroEncCacheKey = s"${getRunId}_${stateStoreId.operatorId}_" + + val avroEncCacheKey = s"${getRunId(hadoopConf)}_${stateStoreId.operatorId}_" + s"${stateStoreId.partitionId}_$colFamilyName" - def avroEnc = stateStoreEncoding match { - case "avro" => Some( - RocksDBStateStoreProvider.avroEncoderMap.get( - avroEncCacheKey, - new java.util.concurrent.Callable[AvroEncoder] { - override def call(): AvroEncoder = getAvroEnc(keyStateEncoderSpec, valueSchema) - } - ) - ) - case "unsaferow" => None - } + val avroEnc = getAvroEnc( + stateStoreEncoding, avroEncCacheKey, keyStateEncoderSpec, valueSchema) keyValueEncoderMap.putIfAbsent(colFamilyName, (RocksDBStateEncoder.getKeyEncoder(keyStateEncoderSpec, useColumnFamilies, @@ -99,63 +90,6 @@ private[sql] class RocksDBStateStoreProvider useMultipleValuesPerKey, avroEnc))) } - private def getRunId: String = { - val runId = hadoopConf.get(StreamExecution.RUN_ID_KEY) - if (runId != null) { - runId - } else { - assert(Utils.isTesting, "Failed to find query id/batch Id in task context") - UUID.randomUUID().toString - } - } - - private def getAvroSerializer(schema: StructType): AvroSerializer = { - val avroType = SchemaConverters.toAvroType(schema) - new AvroSerializer(schema, avroType, nullable = false) - } - - private def getAvroDeserializer(schema: StructType): AvroDeserializer = { - val avroType = SchemaConverters.toAvroType(schema) - val avroOptions = AvroOptions(Map.empty) - new AvroDeserializer(avroType, schema, - avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType, - avroOptions.stableIdPrefixForUnionType, avroOptions.recursiveFieldMaxDepth) - } - - private def getAvroEnc( - keyStateEncoderSpec: KeyStateEncoderSpec, - valueSchema: StructType - ): AvroEncoder = { - val valueSerializer = getAvroSerializer(valueSchema) - val valueDeserializer = getAvroDeserializer(valueSchema) - val keySchema = keyStateEncoderSpec match { - case NoPrefixKeyStateEncoderSpec(schema) => - schema - case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) => - StructType(schema.take(numColsPrefixKey)) - case RangeKeyScanStateEncoderSpec(schema, orderingOrdinals) => - val remainingSchema = { - 0.until(schema.length).diff(orderingOrdinals).map { ordinal => - schema(ordinal) - } - } - StructType(remainingSchema) - } - val suffixKeySchema = keyStateEncoderSpec match { - case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) => - Some(StructType(schema.drop(numColsPrefixKey))) - case _ => None - } - AvroEncoder( - getAvroSerializer(keySchema), - getAvroDeserializer(keySchema), - valueSerializer, - valueDeserializer, - suffixKeySchema.map(getAvroSerializer), - suffixKeySchema.map(getAvroDeserializer) - ) - } - override def get(key: UnsafeRow, colFamilyName: String): UnsafeRow = { verify(key != null, "Key cannot be 
null") verifyColFamilyOperations("get", colFamilyName) @@ -454,10 +388,17 @@ private[sql] class RocksDBStateStoreProvider defaultColFamilyId = Some(rocksDB.createColFamilyIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME)) } + val colFamilyName = StateStore.DEFAULT_COL_FAMILY_NAME + // Create cache key using store ID to avoid collisions + val avroEncCacheKey = s"${getRunId(hadoopConf)}_${stateStoreId.operatorId}_" + + s"${stateStoreId.partitionId}_$colFamilyName" + val avroEnc = getAvroEnc( + stateStoreEncoding, avroEncCacheKey, keyStateEncoderSpec, valueSchema) + keyValueEncoderMap.putIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME, (RocksDBStateEncoder.getKeyEncoder(keyStateEncoderSpec, - useColumnFamilies, defaultColFamilyId), - RocksDBStateEncoder.getValueEncoder(valueSchema, useMultipleValuesPerKey))) + useColumnFamilies, defaultColFamilyId, avroEnc), + RocksDBStateEncoder.getValueEncoder(valueSchema, useMultipleValuesPerKey, avroEnc))) } override def stateStoreId: StateStoreId = stateStoreId_ @@ -682,6 +623,82 @@ object RocksDBStateStoreProvider { new NonFateSharingCache(guavaCache) } + def getAvroEnc( + stateStoreEncoding: String, + avroEncCacheKey: String, + keyStateEncoderSpec: KeyStateEncoderSpec, + valueSchema: StructType): Option[AvroEncoder] = { + + stateStoreEncoding match { + case "avro" => Some( + RocksDBStateStoreProvider.avroEncoderMap.get( + avroEncCacheKey, + new java.util.concurrent.Callable[AvroEncoder] { + override def call(): AvroEncoder = createAvroEnc(keyStateEncoderSpec, valueSchema) + } + ) + ) + case "unsaferow" => None + } + } + + private def getRunId(hadoopConf: Configuration): String = { + val runId = hadoopConf.get(StreamExecution.RUN_ID_KEY) + if (runId != null) { + runId + } else { + assert(Utils.isTesting, "Failed to find query id/batch Id in task context") + UUID.randomUUID().toString + } + } + + private def getAvroSerializer(schema: StructType): AvroSerializer = { + val avroType = SchemaConverters.toAvroType(schema) + new AvroSerializer(schema, avroType, nullable = false) + } + + private def getAvroDeserializer(schema: StructType): AvroDeserializer = { + val avroType = SchemaConverters.toAvroType(schema) + val avroOptions = AvroOptions(Map.empty) + new AvroDeserializer(avroType, schema, + avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType, + avroOptions.stableIdPrefixForUnionType, avroOptions.recursiveFieldMaxDepth) + } + + private def createAvroEnc( + keyStateEncoderSpec: KeyStateEncoderSpec, + valueSchema: StructType + ): AvroEncoder = { + val valueSerializer = getAvroSerializer(valueSchema) + val valueDeserializer = getAvroDeserializer(valueSchema) + val keySchema = keyStateEncoderSpec match { + case NoPrefixKeyStateEncoderSpec(schema) => + schema + case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) => + StructType(schema.take(numColsPrefixKey)) + case RangeKeyScanStateEncoderSpec(schema, orderingOrdinals) => + val remainingSchema = { + 0.until(schema.length).diff(orderingOrdinals).map { ordinal => + schema(ordinal) + } + } + StructType(remainingSchema) + } + val suffixKeySchema = keyStateEncoderSpec match { + case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) => + Some(StructType(schema.drop(numColsPrefixKey))) + case _ => None + } + AvroEncoder( + getAvroSerializer(keySchema), + getAvroDeserializer(keySchema), + valueSerializer, + valueDeserializer, + suffixKeySchema.map(getAvroSerializer), + suffixKeySchema.map(getAvroDeserializer) + ) + } + // Native operation latencies report as latency in microseconds // as 
SQLMetrics support millis. Convert the value to millis val CUSTOM_METRIC_GET_TIME = StateStoreCustomTimingMetric(