From eee629be856740737b592fb7f201a12c7c2a6b99 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Mon, 28 Oct 2024 15:44:09 -0700 Subject: [PATCH 1/3] init --- .../execution/streaming/ListStateImpl.scala | 6 +- .../execution/streaming/MapStateImpl.scala | 8 +- .../StateStoreColumnFamilySchemaUtils.scala | 32 +- .../streaming/StateTypesEncoderUtils.scala | 21 +- .../StatefulProcessorHandleImpl.scala | 4 + .../streaming/state/RocksDBStateEncoder.scala | 339 +++++++++++++++++- .../state/RocksDBStateStoreSuite.scala | 42 +++ 7 files changed, 435 insertions(+), 17 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala index 50294fa3d0587..0c4f900e83664 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala @@ -212,9 +212,9 @@ class ListStateImpl[S]( if (usingAvro) { val encodedKey: Array[Byte] = avroTypesEncoder.encodeGroupingKey() store.remove(encodedKey, stateName) - val entryCount = getEntryCount(encodedKey) - TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows", entryCount) - removeEntryCount(encodedKey) +// val entryCount = getEntryCount(encodedKey) +// TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows", entryCount) +// removeEntryCount(encodedKey) } else { val encodedKey = unsafeRowTypesEncoder.encodeGroupingKey() store.remove(encodedKey, stateName) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala index f8dac08e0650c..73cf6813c7966 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala @@ -42,7 +42,7 @@ class MapStateImpl[K, V]( keyExprEnc: ExpressionEncoder[Any], userKeyEnc: Encoder[K], valEncoder: Encoder[V], - avroSerde: Option[AvroEncoderSpec], + avroEnc: Option[AvroEncoderSpec], metrics: Map[String, SQLMetric] = Map.empty) extends MapState[K, V] with Logging { // Pack grouping key and user key together as a prefixed composite key @@ -53,9 +53,9 @@ class MapStateImpl[K, V]( // If we are using Avro, the avroSerde parameter must be populated // else, we will default to using UnsafeRow. 
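As the comment above notes, each state implementation now picks between an UnsafeRow encoder and an Avro encoder depending on whether an AvroEncoderSpec is supplied, and the Avro path throughout this patch follows the same byte-level pattern: serialize a row to a GenericRecord, write it with a GenericDatumWriter over a direct binary encoder, and read it back with a GenericDatumReader over a binary decoder. The following is a minimal, self-contained sketch of just that round trip, assuming only the stock Avro generic APIs; the TtlKeySketch record, its field layout (mirroring the patch's expirationMs/groupingKey TTL key), and the helper object are illustrative and not types from the patch, and the sketch deliberately omits the Spark AvroSerializer/AvroDeserializer step that converts between InternalRow and GenericRecord.

import java.io.ByteArrayOutputStream
import java.nio.ByteBuffer

import org.apache.avro.SchemaBuilder
import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord}
import org.apache.avro.io.{DecoderFactory, EncoderFactory}

object AvroStateRowRoundTripSketch {

  // Hypothetical schema shaped like the TTL key used in this patch:
  // a long expiration timestamp plus an opaque, already-encoded grouping key.
  private val schema = SchemaBuilder.record("TtlKeySketch").fields()
    .requiredLong("expirationMs")
    .requiredBytes("groupingKey")
    .endRecord()

  def encode(expirationMs: Long, groupingKey: Array[Byte]): Array[Byte] = {
    val record = new GenericData.Record(schema)
    record.put("expirationMs", expirationMs)
    record.put("groupingKey", ByteBuffer.wrap(groupingKey))

    val out = new ByteArrayOutputStream()
    val encoder = EncoderFactory.get().directBinaryEncoder(out, null)
    val writer = new GenericDatumWriter[GenericRecord](schema)
    writer.write(record, encoder) // GenericRecord -> Avro binary
    encoder.flush()
    out.toByteArray
  }

  def decode(bytes: Array[Byte]): GenericRecord = {
    val decoder = DecoderFactory.get().binaryDecoder(bytes, 0, bytes.length, null)
    val reader = new GenericDatumReader[GenericRecord](schema)
    reader.read(null, decoder) // Avro binary -> GenericRecord
  }
}

A quick sanity check of the sketch: decode(encode(42L, Array[Byte](1, 2, 3))).get("expirationMs") yields 42L, which is the same write-then-read shape the encoders below use when persisting state rows as raw bytes.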
- private val usingAvro: Boolean = avroSerde.isDefined + private val usingAvro: Boolean = avroEnc.isDefined private val avroTypesEncoder = new CompositeKeyAvroRowEncoder( - keyExprEnc, userKeyEnc, valEncoder, stateName, hasTtl = false, avroSerde) + keyExprEnc, userKeyEnc, valEncoder, stateName, hasTtl = false, avroEnc) private val unsafeRowTypesEncoder = new CompositeKeyUnsafeRowEncoder( keyExprEnc, userKeyEnc, valEncoder, stateName, hasTtl = false) @@ -163,4 +163,4 @@ class MapStateImpl[K, V]( removeKey(key) } } -} \ No newline at end of file +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala index d120db8a68db6..713d1b99beb03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala @@ -20,9 +20,8 @@ import org.apache.spark.sql.Encoder import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions, AvroSerializer, SchemaConverters} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._ -import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, StateStoreColFamilySchema} -import org.apache.spark.sql.execution.streaming.state.AvroEncoderSpec -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.execution.streaming.state.{AvroEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, RangeKeyScanStateEncoderSpec, StateStoreColFamilySchema} +import org.apache.spark.sql.types.{NullType, StructField, StructType} object StateStoreColumnFamilySchemaUtils { @@ -61,6 +60,33 @@ class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) { } } + def getTtlStateSchema[T]( + stateName: String, + keyEncoder: ExpressionEncoder[Any]): StateStoreColFamilySchema = { + val ttlKeySchema = getSingleKeyTTLRowSchema(keyEncoder.schema) + val ttlValSchema = StructType(Array(StructField("__dummy__", NullType))) + StateStoreColFamilySchema( + stateName, + ttlKeySchema, + ttlValSchema, + Some(RangeKeyScanStateEncoderSpec(ttlKeySchema, Seq(0))), + avroEnc = getAvroSerde(keyEncoder.schema, ttlValSchema)) + } + + def getTtlStateSchema[T]( + stateName: String, + keyEncoder: ExpressionEncoder[Any], + userKeyEncoder: ExpressionEncoder[Any]): StateStoreColFamilySchema = { + val ttlKeySchema = getCompositeKeyTTLRowSchema(keyEncoder.schema, userKeyEncoder.schema) + val ttlValSchema = StructType(Array(StructField("__dummy__", NullType))) + StateStoreColFamilySchema( + stateName, + ttlKeySchema, + ttlValSchema, + Some(RangeKeyScanStateEncoderSpec(ttlKeySchema, Seq(0))), + avroEnc = getAvroSerde(keyEncoder.schema, ttlValSchema)) + } + def getValueStateSchema[T]( stateName: String, keyEncoder: ExpressionEncoder[Any], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala index bcab3fdfa967c..5ebcc0d79c8a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala @@ -436,15 +436,34 @@ class 
CompositeKeyAvroRowEncoder[K, V]( /** Class for TTL with single key serialization */ class SingleKeyTTLEncoder( - keyExprEnc: ExpressionEncoder[Any]) { + keyExprEnc: ExpressionEncoder[Any], + avroEncoderSpec: Option[AvroEncoderSpec] = None) { + private lazy val out = new ByteArrayOutputStream private val ttlKeyProjection = UnsafeProjection.create( getSingleKeyTTLRowSchema(keyExprEnc.schema)) + private val ttlKeyType = SchemaConverters.toAvroType(keyExprEnc.schema) + def encodeTTLRow(expirationMs: Long, groupingKey: UnsafeRow): UnsafeRow = { ttlKeyProjection.apply( InternalRow(expirationMs, groupingKey.asInstanceOf[InternalRow])) } + + def encodeTTLRow(expirationMs: Long, groupingKey: Array[Byte]): Array[Byte] = { + val objRow: InternalRow = InternalRow(expirationMs, groupingKey.asInstanceOf[InternalRow]) + val avroData = + avroEncoderSpec.get.valueSerializer.serialize(objRow) // InternalRow -> GenericDataRecord + out.reset() + + val encoder = EncoderFactory.get().directBinaryEncoder(out, null) + val writer = new GenericDatumWriter[Any]( + ttlKeyType) // Defining Avro writer for this struct type + + writer.write(avroData, encoder) // GenericDataRecord -> bytes + encoder.flush() + out.toByteArray + } } /** Class for TTL with composite key serialization */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 88f0be0b2269f..96c67fb0c7385 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -399,6 +399,8 @@ class DriverStatefulProcessorHandleImpl( getValueStateSchema(stateName, keyExprEnc, valEncoder, true) checkIfDuplicateVariableDefined(stateName) columnFamilySchemas.put(stateName, colFamilySchema) + val ttlColFamilySchema = schemaUtils.getTtlStateSchema("$ttl_" + stateName, keyExprEnc) + columnFamilySchemas.put("$ttl_" + stateName, ttlColFamilySchema) val stateVariableInfo = TransformWithStateVariableUtils. getValueState(stateName, ttlEnabled = true) stateVariableInfos.put(stateName, stateVariableInfo) @@ -426,6 +428,8 @@ class DriverStatefulProcessorHandleImpl( getListStateSchema(stateName, keyExprEnc, valEncoder, true) checkIfDuplicateVariableDefined(stateName) columnFamilySchemas.put(stateName, colFamilySchema) + val ttlColFamilySchema = schemaUtils.getTtlStateSchema("$ttl_" + stateName, keyExprEnc) + columnFamilySchemas.put("$ttl_" + stateName, ttlColFamilySchema) val stateVariableInfo = TransformWithStateVariableUtils. 
getListState(stateName, ttlEnabled = true) stateVariableInfos.put(stateName, stateVariableInfo) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index fc7bbf9ddfb0a..ba3036bbd652f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -729,15 +729,342 @@ class RangeKeyScanStateEncoder( override def supportPrefixKeyScan: Boolean = true + override def encodePrefixKeyBytes(prefixKeyBytes: Array[Byte]): Array[Byte] = { + // For range scan prefix key, we need to transform the bytes to maintain ordering + // This requires parsing each field according to its type and applying the range-scan encoding + + var pos = 0 + var totalSize = 0 + + // Calculate total size needed - each field needs marker byte + field size + rangeScanKeyFieldsWithOrdinal.foreach { case (field, _) => + totalSize += 1 // marker byte + val fieldSize = field.dataType match { + case BooleanType | ByteType => 1 + case ShortType => 2 + case IntegerType | FloatType => 4 + case LongType | DoubleType => 8 + case _ => throw new IllegalArgumentException(s"Unsupported type ${field.dataType}") + } + totalSize += fieldSize + } + + // Create byte array for the encoded fields + val encodedFields = new Array[Byte](totalSize) + var encodedPos = 0 + + // Process each field + rangeScanKeyFieldsWithOrdinal.foreach { case (field, _) => + val valueSize = field.dataType match { + case BooleanType | ByteType => 1 + case ShortType => 2 + case IntegerType | FloatType => 4 + case LongType | DoubleType => 8 + } + + // Check if we have enough bytes left to read the value + if (pos + valueSize > prefixKeyBytes.length) { + throw new IllegalArgumentException(s"Not enough bytes to read ${field.dataType}") + } + + // Create buffer for reading the value in native byte order + val bbuf = ByteBuffer.wrap(prefixKeyBytes, pos, valueSize) + val bbufBE = ByteBuffer.allocate(valueSize) + bbufBE.order(ByteOrder.BIG_ENDIAN) + + field.dataType match { + case ByteType => + val value = bbuf.get() + encodedFields(encodedPos) = if (value == 0) { + 0x02.toByte // null marker + } else if (value < 0) { + 0x00.toByte // negative marker + } else { + 0x01.toByte // positive marker + } + encodedPos += 1 + bbufBE.put(value) + bbufBE.flip() + bbufBE.get(encodedFields, encodedPos, valueSize) + + case ShortType => + val value = bbuf.getShort() + encodedFields(encodedPos) = if (value == 0) { + 0x02.toByte + } else if (value < 0) { + 0x00.toByte + } else { + 0x01.toByte + } + encodedPos += 1 + bbufBE.putShort(value) + bbufBE.flip() + bbufBE.get(encodedFields, encodedPos, valueSize) + + case IntegerType => + val value = bbuf.getInt() + encodedFields(encodedPos) = if (value == 0) { + 0x02.toByte + } else if (value < 0) { + 0x00.toByte + } else { + 0x01.toByte + } + encodedPos += 1 + bbufBE.putInt(value) + bbufBE.flip() + bbufBE.get(encodedFields, encodedPos, valueSize) + + case LongType => + val value = bbuf.getLong() + encodedFields(encodedPos) = if (value == 0) { + 0x02.toByte + } else if (value < 0) { + 0x00.toByte + } else { + 0x01.toByte + } + encodedPos += 1 + bbufBE.putLong(value) + bbufBE.flip() + bbufBE.get(encodedFields, encodedPos, valueSize) + + case FloatType => + val value = bbuf.getFloat() + val bits = java.lang.Float.floatToRawIntBits(value) + 
encodedFields(encodedPos) = if (bits == 0) { + 0x02.toByte + } else if ((bits & floatSignBitMask) != 0) { + // Negative float - need to flip bits for correct ordering + 0x00.toByte + } else { + 0x01.toByte + } + encodedPos += 1 + val floatToWrite = if ((bits & floatSignBitMask) != 0) { + java.lang.Float.intBitsToFloat(bits ^ floatFlipBitMask) + } else value + bbufBE.putFloat(floatToWrite) + bbufBE.flip() + bbufBE.get(encodedFields, encodedPos, valueSize) + + case DoubleType => + val value = bbuf.getDouble() + val bits = java.lang.Double.doubleToRawLongBits(value) + encodedFields(encodedPos) = if (bits == 0) { + 0x02.toByte + } else if ((bits & doubleSignBitMask) != 0) { + // Negative double - need to flip bits for correct ordering + 0x00.toByte + } else { + 0x01.toByte + } + encodedPos += 1 + val doubleToWrite = if ((bits & doubleSignBitMask) != 0) { + java.lang.Double.longBitsToDouble(bits ^ doubleFlipBitMask) + } else value + bbufBE.putDouble(doubleToWrite) + bbufBE.flip() + bbufBE.get(encodedFields, encodedPos, valueSize) + } + + encodedPos += valueSize + pos += valueSize + } + + // Now wrap these encoded fields with version and column family prefix if needed + val (result, startingOffset) = encodeColumnFamilyPrefix(encodedFields.length + 4) + Platform.putInt(result, startingOffset, encodedFields.length) + Platform.copyMemory(encodedFields, Platform.BYTE_ARRAY_OFFSET, + result, startingOffset + 4, encodedFields.length) + + result + } + + override def encodeKeyBytes(keyBytes: Array[Byte]): Array[Byte] = { + // For full key encoding, we need to: + // 1. Split the input bytes into range scan fields and remaining fields + // 2. Encode the range scan fields using the same logic as encodePrefixKeyBytes + // 3. Append the remaining fields + // 4. Add column family prefix if needed + + var pos = 0 + var totalSize = 0 + + // Calculate size needed for range scan fields + rangeScanKeyFieldsWithOrdinal.foreach { case (field, _) => + totalSize += 1 // marker byte + val fieldSize = field.dataType match { + case BooleanType | ByteType => 1 + case ShortType => 2 + case IntegerType | FloatType => 4 + case LongType | DoubleType => 8 + case _ => throw new IllegalArgumentException(s"Unsupported type ${field.dataType}") + } + totalSize += fieldSize + } + + // Create array for range scan encoded fields + val encodedRangeFields = new Array[Byte](totalSize) + var encodedPos = 0 + + // Encode range scan fields + rangeScanKeyFieldsWithOrdinal.foreach { case (field, _) => + val valueSize = field.dataType match { + case BooleanType | ByteType => 1 + case ShortType => 2 + case IntegerType | FloatType => 4 + case LongType | DoubleType => 8 + } + + if (pos + valueSize > keyBytes.length) { + throw new IllegalArgumentException(s"Not enough bytes to read ${field.dataType}") + } + + val bbuf = ByteBuffer.wrap(keyBytes, pos, valueSize) + val bbufBE = ByteBuffer.allocate(valueSize) + bbufBE.order(ByteOrder.BIG_ENDIAN) + + field.dataType match { + case ByteType => + val value = bbuf.get() + encodedRangeFields(encodedPos) = if (value == 0) { + 0x02.toByte + } else if (value < 0) { + 0x00.toByte + } else { + 0x01.toByte + } + encodedPos += 1 + bbufBE.put(value) + bbufBE.flip() + bbufBE.get(encodedRangeFields, encodedPos, valueSize) + + case ShortType => + val value = bbuf.getShort() + encodedRangeFields(encodedPos) = if (value == 0) { + 0x02.toByte + } else if (value < 0) { + 0x00.toByte + } else { + 0x01.toByte + } + encodedPos += 1 + bbufBE.putShort(value) + bbufBE.flip() + bbufBE.get(encodedRangeFields, encodedPos, 
valueSize) + + case IntegerType => + val value = bbuf.getInt() + encodedRangeFields(encodedPos) = if (value == 0) { + 0x02.toByte + } else if (value < 0) { + 0x00.toByte + } else { + 0x01.toByte + } + encodedPos += 1 + bbufBE.putInt(value) + bbufBE.flip() + bbufBE.get(encodedRangeFields, encodedPos, valueSize) + + case LongType => + val value = bbuf.getLong() + encodedRangeFields(encodedPos) = if (value == 0) { + 0x02.toByte + } else if (value < 0) { + 0x00.toByte + } else { + 0x01.toByte + } + encodedPos += 1 + bbufBE.putLong(value) + bbufBE.flip() + bbufBE.get(encodedRangeFields, encodedPos, valueSize) + + case FloatType => + val value = bbuf.getFloat() + val bits = java.lang.Float.floatToRawIntBits(value) + encodedRangeFields(encodedPos) = if (bits == 0) { + 0x02.toByte + } else if ((bits & floatSignBitMask) != 0) { + 0x00.toByte + } else { + 0x01.toByte + } + encodedPos += 1 + val floatToWrite = if ((bits & floatSignBitMask) != 0) { + java.lang.Float.intBitsToFloat(bits ^ floatFlipBitMask) + } else value + bbufBE.putFloat(floatToWrite) + bbufBE.flip() + bbufBE.get(encodedRangeFields, encodedPos, valueSize) + + case DoubleType => + val value = bbuf.getDouble() + val bits = java.lang.Double.doubleToRawLongBits(value) + encodedRangeFields(encodedPos) = if (bits == 0) { + 0x02.toByte + } else if ((bits & doubleSignBitMask) != 0) { + 0x00.toByte + } else { + 0x01.toByte + } + encodedPos += 1 + val doubleToWrite = if ((bits & doubleSignBitMask) != 0) { + java.lang.Double.longBitsToDouble(bits ^ doubleFlipBitMask) + } else value + bbufBE.putDouble(doubleToWrite) + bbufBE.flip() + bbufBE.get(encodedRangeFields, encodedPos, valueSize) + } + + encodedPos += valueSize + pos += valueSize + } + + // Get remaining bytes + val remainingBytes = if (pos < keyBytes.length) { + val remaining = new Array[Byte](keyBytes.length - pos) + Platform.copyMemory(keyBytes, Platform.BYTE_ARRAY_OFFSET + pos, + remaining, Platform.BYTE_ARRAY_OFFSET, remaining.length) + remaining + } else Array.empty[Byte] + + // Create final result with version and column family prefix if needed + val (result, startingOffset) = encodeColumnFamilyPrefix( + 4 + encodedRangeFields.length + remainingBytes.length) + + Platform.putInt(result, startingOffset, encodedRangeFields.length) + Platform.copyMemory(encodedRangeFields, Platform.BYTE_ARRAY_OFFSET, + result, startingOffset + 4, encodedRangeFields.length) + + if (remainingBytes.nonEmpty) { + Platform.copyMemory(remainingBytes, Platform.BYTE_ARRAY_OFFSET, + result, startingOffset + 4 + encodedRangeFields.length, remainingBytes.length) + } - override def encodePrefixKeyBytes(prefixKey: Array[Byte]): Array[Byte] = - throw new UnsupportedOperationException + result + } + + override def decodeKeyBytes(keyBytes: Array[Byte]): Array[Byte] = { + if (keyBytes == null) return null + + // Skip column family prefix if present + val startOffset = decodeKeyStartOffset - override def encodeKeyBytes(row: Array[Byte]): Array[Byte] = - throw new UnsupportedOperationException + // Read prefix length + val prefixLen = Platform.getInt(keyBytes, startOffset) - override def decodeKeyBytes(keyBytes: Array[Byte]): Array[Byte] = - throw new UnsupportedOperationException + // Total length of data we want to return + val totalLen = keyBytes.length - 4 - offsetForColFamilyPrefix + + val result = new Array[Byte](totalLen) + Platform.copyMemory(keyBytes, startOffset + 4, + result, Platform.BYTE_ARRAY_OFFSET, totalLen) + + result + } } /** diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index e1bd9dd38066b..318349a9a434c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -821,6 +821,48 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } + testWithColumnFamilies("rocksdb range scan - ordering cols and key schema cols are same", + TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => + + // use the same schema as value schema for single col key schema + tryWithProviderResource(newStoreProvider(valueSchema, + RangeKeyScanStateEncoderSpec(valueSchema, Seq(0)), colFamiliesEnabled)) { provider => + val store = provider.getStore(0) + val cfName = if (colFamiliesEnabled) "testColFamily" else "default" + if (colFamiliesEnabled) { + store.createColFamilyIfAbsent(cfName, + valueSchema, valueSchema, + RangeKeyScanStateEncoderSpec(valueSchema, Seq(0))) + } + + val timerTimestamps = Seq(931, 8000, 452300, 4200, + -3545, -343, 133, -90, -8014490, -79247, + 90, 1, 2, 8, 3, 35, 6, 9, 5, -233) + timerTimestamps.foreach { ts => + // non-timestamp col is of variable size + val keyRow = dataToValueRow(ts) + val valueRow = dataToValueRow(1) + store.put(keyRow, valueRow, cfName) + assert(valueRowToData(store.get(keyRow, cfName)) === 1) + } + + val result = store.iterator(cfName).map { kv => + valueRowToData(kv.key) + }.toSeq + assert(result === timerTimestamps.sorted) + + // also check for prefix scan + timerTimestamps.foreach { ts => + val prefix = dataToValueRow(ts) + val result = store.prefixScan(prefix, cfName).map { kv => + assert(valueRowToData(kv.value) === 1) + valueRowToData(kv.key) + }.toSeq + assert(result.size === 1) + } + } + } + testWithColumnFamilies("rocksdb range scan - with prefix scan", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => From 2a9e5f32a26f6e549fb2444b78719186ee9a5e6d Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Tue, 29 Oct 2024 14:00:34 -0700 Subject: [PATCH 2/3] ttl tests pass --- .../streaming/ListStateImplWithTTL.scala | 43 +++--- .../StateStoreColumnFamilySchemaUtils.scala | 16 +- .../streaming/StateTypesEncoderUtils.scala | 52 +++++-- .../StatefulProcessorHandleImpl.scala | 6 +- .../sql/execution/streaming/TTLState.scala | 129 ++++++++++++++-- .../streaming/ValueStateImplWithTTL.scala | 142 ++++++++++++++---- .../streaming/state/RocksDBStateEncoder.scala | 13 +- .../StateSchemaCompatibilityChecker.scala | 1 + .../streaming/TransformWithStateTTLTest.scala | 12 +- .../TransformWithValueStateTTLSuite.scala | 5 +- 10 files changed, 325 insertions(+), 94 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala index 639683e5ff549..cec4a02bd1842 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala @@ -47,8 +47,9 @@ class ListStateImplWithTTL[S]( ttlConfig: TTLConfig, batchTimestampMs: Long, avroEnc: Option[AvroEncoderSpec], // TODO: Add Avro Encoding support for TTL + ttlAvroEnc: Option[AvroEncoderSpec], 
metrics: Map[String, SQLMetric] = Map.empty) - extends SingleKeyTTLStateImpl(stateName, store, keyExprEnc, batchTimestampMs) + extends SingleKeyTTLStateImpl(stateName, store, keyExprEnc, batchTimestampMs, ttlAvroEnc) with ListStateMetricsImpl with ListState[S] { @@ -56,7 +57,7 @@ class ListStateImplWithTTL[S]( override def baseStateName: String = stateName override def exprEncSchema: StructType = keyExprEnc.schema - private lazy val stateTypesEncoder = UnsafeRowTypesEncoder(keyExprEnc, valEncoder, + private lazy val unsafeRowTypesEncoder = UnsafeRowTypesEncoder(keyExprEnc, valEncoder, stateName, hasTtl = true) private lazy val ttlExpirationMs = @@ -80,19 +81,19 @@ class ListStateImplWithTTL[S]( * empty iterator is returned. */ override def get(): Iterator[S] = { - val encodedKey = stateTypesEncoder.encodeGroupingKey() + val encodedKey = unsafeRowTypesEncoder.encodeGroupingKey() val unsafeRowValuesIterator = store.valuesIterator(encodedKey, stateName) new NextIterator[S] { override protected def getNext(): S = { val iter = unsafeRowValuesIterator.dropWhile { row => - stateTypesEncoder.isExpired(row, batchTimestampMs) + unsafeRowTypesEncoder.isExpired(row, batchTimestampMs) } if (iter.hasNext) { val currentRow = iter.next() - stateTypesEncoder.decodeValue(currentRow) + unsafeRowTypesEncoder.decodeValue(currentRow) } else { finished = true null.asInstanceOf[S] @@ -107,13 +108,13 @@ class ListStateImplWithTTL[S]( override def put(newState: Array[S]): Unit = { validateNewState(newState) - val encodedKey = stateTypesEncoder.encodeGroupingKey() + val encodedKey = unsafeRowTypesEncoder.encodeGroupingKey() var isFirst = true var entryCount = 0L TWSMetricsUtils.resetMetric(metrics, "numUpdatedStateRows") newState.foreach { v => - val encodedValue = stateTypesEncoder.encodeValue(v, ttlExpirationMs) + val encodedValue = unsafeRowTypesEncoder.encodeValue(v, ttlExpirationMs) if (isFirst) { store.put(encodedKey, encodedValue, stateName) isFirst = false @@ -130,10 +131,10 @@ class ListStateImplWithTTL[S]( /** Append an entry to the list. */ override def appendValue(newState: S): Unit = { StateStoreErrors.requireNonNullStateValue(newState, stateName) - val encodedKey = stateTypesEncoder.encodeGroupingKey() + val encodedKey = unsafeRowTypesEncoder.encodeGroupingKey() val entryCount = getEntryCount(encodedKey) store.merge(encodedKey, - stateTypesEncoder.encodeValue(newState, ttlExpirationMs), stateName) + unsafeRowTypesEncoder.encodeValue(newState, ttlExpirationMs), stateName) TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows") upsertTTLForStateKey(encodedKey) updateEntryCount(encodedKey, entryCount + 1) @@ -143,10 +144,10 @@ class ListStateImplWithTTL[S]( override def appendList(newState: Array[S]): Unit = { validateNewState(newState) - val encodedKey = stateTypesEncoder.encodeGroupingKey() + val encodedKey = unsafeRowTypesEncoder.encodeGroupingKey() var entryCount = getEntryCount(encodedKey) newState.foreach { v => - val encodedValue = stateTypesEncoder.encodeValue(v, ttlExpirationMs) + val encodedValue = unsafeRowTypesEncoder.encodeValue(v, ttlExpirationMs) store.merge(encodedKey, encodedValue, stateName) entryCount += 1 TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows") @@ -157,7 +158,7 @@ class ListStateImplWithTTL[S]( /** Remove this state. 
*/ override def clear(): Unit = { - val encodedKey = stateTypesEncoder.encodeGroupingKey() + val encodedKey = unsafeRowTypesEncoder.encodeGroupingKey() store.remove(encodedKey, stateName) val entryCount = getEntryCount(encodedKey) TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows", entryCount) @@ -188,7 +189,7 @@ class ListStateImplWithTTL[S]( var isFirst = true var entryCount = 0L unsafeRowValuesIterator.foreach { encodedValue => - if (!stateTypesEncoder.isExpired(encodedValue, batchTimestampMs)) { + if (!unsafeRowTypesEncoder.isExpired(encodedValue, batchTimestampMs)) { if (isFirst) { store.put(groupingKey, encodedValue, stateName) isFirst = false @@ -205,6 +206,10 @@ class ListStateImplWithTTL[S]( numValuesExpired } + override def clearIfExpired(groupingKey: Array[Byte]): Long = { + 0 + } + private def upsertTTLForStateKey(encodedGroupingKey: UnsafeRow): Unit = { upsertTTLForStateKey(ttlExpirationMs, encodedGroupingKey) } @@ -221,10 +226,10 @@ class ListStateImplWithTTL[S]( * end of the micro-batch. */ private[sql] def getWithoutEnforcingTTL(): Iterator[S] = { - val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() + val encodedGroupingKey = unsafeRowTypesEncoder.encodeGroupingKey() val unsafeRowValuesIterator = store.valuesIterator(encodedGroupingKey, stateName) unsafeRowValuesIterator.map { valueUnsafeRow => - stateTypesEncoder.decodeValue(valueUnsafeRow) + unsafeRowTypesEncoder.decodeValue(valueUnsafeRow) } } @@ -232,11 +237,11 @@ class ListStateImplWithTTL[S]( * Read the ttl value associated with the grouping key. */ private[sql] def getTTLValues(): Iterator[(S, Long)] = { - val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() + val encodedGroupingKey = unsafeRowTypesEncoder.encodeGroupingKey() val unsafeRowValuesIterator = store.valuesIterator(encodedGroupingKey, stateName) unsafeRowValuesIterator.map { valueUnsafeRow => - (stateTypesEncoder.decodeValue(valueUnsafeRow), - stateTypesEncoder.decodeTtlExpirationMs(valueUnsafeRow).get) + (unsafeRowTypesEncoder.decodeValue(valueUnsafeRow), + unsafeRowTypesEncoder.decodeTtlExpirationMs(valueUnsafeRow).get) } } @@ -245,6 +250,6 @@ class ListStateImplWithTTL[S]( * grouping key. 
*/ private[sql] def getValuesInTTLState(): Iterator[Long] = { - getValuesInTTLState(stateTypesEncoder.encodeGroupingKey()) + getValuesInTTLState(unsafeRowTypesEncoder.encodeGroupingKey()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala index 713d1b99beb03..6bd6c959746f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala @@ -16,12 +16,13 @@ */ package org.apache.spark.sql.execution.streaming +import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions, AvroSerializer, SchemaConverters} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._ import org.apache.spark.sql.execution.streaming.state.{AvroEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, RangeKeyScanStateEncoderSpec, StateStoreColFamilySchema} -import org.apache.spark.sql.types.{NullType, StructField, StructType} +import org.apache.spark.sql.types.{BinaryType, LongType, NullType, StructField, StructType} object StateStoreColumnFamilySchemaUtils { @@ -29,7 +30,7 @@ object StateStoreColumnFamilySchemaUtils { new StateStoreColumnFamilySchemaUtils(initializeAvroSerde) } -class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) { +class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) extends Logging { private def getAvroSerde( keySchema: StructType, @@ -40,6 +41,9 @@ class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) { val avroOptions = AvroOptions(Map.empty) val keyAvroType = SchemaConverters.toAvroType(keySchema) val keySer = new AvroSerializer(keySchema, keyAvroType, nullable = false) + val keyDe = new AvroDeserializer(keyAvroType, keySchema, + avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType, + avroOptions.stableIdPrefixForUnionType, avroOptions.recursiveFieldMaxDepth) val ser = new AvroSerializer(valSchema, avroType, nullable = false) val de = new AvroDeserializer(avroType, valSchema, avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType, @@ -54,7 +58,7 @@ class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) { (Some(ckSer), Some(ckDe)) case None => (None, None) } - Some(AvroEncoderSpec(keySer, ser, de, ckSerDe._1, ckSerDe._2)) + Some(AvroEncoderSpec(keySer, keyDe, ser, de, ckSerDe._1, ckSerDe._2)) } else { None } @@ -63,14 +67,16 @@ class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) { def getTtlStateSchema[T]( stateName: String, keyEncoder: ExpressionEncoder[Any]): StateStoreColFamilySchema = { - val ttlKeySchema = getSingleKeyTTLRowSchema(keyEncoder.schema) + val ttlKeySchema = new StructType() + .add("expirationMs", LongType) + .add("groupingKey", BinaryType) val ttlValSchema = StructType(Array(StructField("__dummy__", NullType))) StateStoreColFamilySchema( stateName, ttlKeySchema, ttlValSchema, Some(RangeKeyScanStateEncoderSpec(ttlKeySchema, Seq(0))), - avroEnc = getAvroSerde(keyEncoder.schema, ttlValSchema)) + avroEnc = getAvroSerde(ttlKeySchema, ttlValSchema)) } def getTtlStateSchema[T]( diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala index 5ebcc0d79c8a9..d51a5180907e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala @@ -22,6 +22,7 @@ import java.io.ByteArrayOutputStream import org.apache.avro.generic.{GenericDatumReader, GenericDatumWriter} import org.apache.avro.io.{DecoderFactory, EncoderFactory} +import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.avro.SchemaConverters import org.apache.spark.sql.catalyst.InternalRow @@ -49,6 +50,11 @@ object TransformWithStateKeyValueRowSchemaUtils { .add("expirationMs", LongType) .add("groupingKey", keySchema) + def getSingleKeyTTLAvroRowSchema: StructType = + new StructType() + .add("expirationMs", LongType) + .add("groupingKey", BinaryType) + def getCompositeKeyTTLRowSchema( groupingKeySchema: StructType, userKeySchema: StructType): StructType = @@ -188,7 +194,7 @@ class AvroTypesEncoder[V]( valEncoder: Encoder[V], stateName: String, hasTtl: Boolean, - avroSerde: Option[AvroEncoderSpec]) extends StateTypesEncoder[V, Array[Byte]] { + avroSerde: Option[AvroEncoderSpec]) extends StateTypesEncoder[V, Array[Byte]] with Logging { val out = new ByteArrayOutputStream @@ -202,7 +208,7 @@ class AvroTypesEncoder[V]( private val keyAvroType = SchemaConverters.toAvroType(keySchema) // case class -> dataType - private val valSchema: StructType = valEncoder.schema + private val valSchema: StructType = getValueSchemaWithTTL(valEncoder.schema, hasTtl) // dataType -> avroType private val valueAvroType = SchemaConverters.toAvroType(valSchema) @@ -251,15 +257,38 @@ class AvroTypesEncoder[V]( } override def encodeValue(value: V, expirationMs: Long): Array[Byte] = { - throw new UnsupportedOperationException + val objRow: InternalRow = objToRowSerializer.apply(value).copy() // V -> InternalRow + val avroData = + avroSerde.get.valueSerializer.serialize( + InternalRow(objRow, expirationMs)) // InternalRow -> GenericDataRecord + out.reset() + + val encoder = EncoderFactory.get().directBinaryEncoder(out, null) + val writer = new GenericDatumWriter[Any]( + valueAvroType) // Defining Avro writer for this struct type + + writer.write(avroData, encoder) // GenericDataRecord -> bytes + encoder.flush() + out.toByteArray } override def decodeTtlExpirationMs(row: Array[Byte]): Option[Long] = { - throw new UnsupportedOperationException + val reader = new GenericDatumReader[Any](valueAvroType) + val decoder = DecoderFactory.get().binaryDecoder(row, 0, row.length, null) + val genericData = reader.read(null, decoder) // bytes -> GenericDataRecord + val internalRow = avroSerde.get.valueDeserializer.deserialize( + genericData).orNull.asInstanceOf[InternalRow] // GenericDataRecord -> InternalRow + val expirationMs = internalRow.getLong(1) + if (expirationMs == -1) { + None + } else { + Some(expirationMs) + } } override def isExpired(row: Array[Byte], batchTimestampMs: Long): Boolean = { - throw new UnsupportedOperationException + val expirationMs = decodeTtlExpirationMs(row) + expirationMs.exists(StateTTL.isExpired(_, batchTimestampMs)) } } @@ -437,13 +466,14 @@ class CompositeKeyAvroRowEncoder[K, V]( /** Class for TTL with single key serialization */ class SingleKeyTTLEncoder( keyExprEnc: ExpressionEncoder[Any], - 
avroEncoderSpec: Option[AvroEncoderSpec] = None) { + avroEnc: Option[AvroEncoderSpec] = None) extends Logging { private lazy val out = new ByteArrayOutputStream private val ttlKeyProjection = UnsafeProjection.create( getSingleKeyTTLRowSchema(keyExprEnc.schema)) - private val ttlKeyType = SchemaConverters.toAvroType(keyExprEnc.schema) + private val ttlKeyAvroType = SchemaConverters.toAvroType( + getSingleKeyTTLAvroRowSchema) def encodeTTLRow(expirationMs: Long, groupingKey: UnsafeRow): UnsafeRow = { ttlKeyProjection.apply( @@ -451,14 +481,12 @@ class SingleKeyTTLEncoder( } def encodeTTLRow(expirationMs: Long, groupingKey: Array[Byte]): Array[Byte] = { - val objRow: InternalRow = InternalRow(expirationMs, groupingKey.asInstanceOf[InternalRow]) - val avroData = - avroEncoderSpec.get.valueSerializer.serialize(objRow) // InternalRow -> GenericDataRecord + val internalRow = InternalRow(expirationMs, groupingKey) + val avroData = avroEnc.get.keySerializer.serialize(internalRow) out.reset() - val encoder = EncoderFactory.get().directBinaryEncoder(out, null) val writer = new GenericDatumWriter[Any]( - ttlKeyType) // Defining Avro writer for this struct type + ttlKeyAvroType) // Defining Avro writer for this struct type writer.write(avroData, encoder) // GenericDataRecord -> bytes encoder.flush() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 96c67fb0c7385..a2e4f20bef0bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -173,7 +173,8 @@ class StatefulProcessorHandleImpl( assert(batchTimestampMs.isDefined) val valueStateWithTTL = new ValueStateImplWithTTL[T](store, stateName, - keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get, avroEnc = None, metrics) + keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get, schemas(stateName).avroEnc, + schemas("$ttl_" + stateName).avroEnc, metrics) ttlStates.add(valueStateWithTTL) TWSMetricsUtils.incrementMetric(metrics, "numValueStateWithTTLVars") @@ -283,7 +284,8 @@ class StatefulProcessorHandleImpl( assert(batchTimestampMs.isDefined) val listStateWithTTL = new ListStateImplWithTTL[T](store, stateName, - keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get, avroEnc = None, metrics) + keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get, schemas(stateName).avroEnc, + schemas("$ttl_" + stateName).avroEnc, metrics) TWSMetricsUtils.incrementMetric(metrics, "numListStateWithTTLVars") ttlStates.add(listStateWithTTL) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala index 8811c59a50745..7b24f72d0db39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala @@ -18,12 +18,17 @@ package org.apache.spark.sql.execution.streaming import java.time.Duration +import org.apache.avro.generic.GenericDatumReader +import org.apache.avro.io.DecoderFactory + +import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder +import org.apache.spark.sql.avro.SchemaConverters import org.apache.spark.sql.catalyst.InternalRow import 
org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._ -import org.apache.spark.sql.execution.streaming.state.{RangeKeyScanStateEncoderSpec, StateStore} +import org.apache.spark.sql.execution.streaming.state.{AvroEncoderSpec, RangeKeyScanStateEncoderSpec, StateStore} import org.apache.spark.sql.types._ object StateTTLSchema { @@ -41,6 +46,10 @@ case class SingleKeyTTLRow( groupingKey: UnsafeRow, expirationMs: Long) +case class SingleKeyByteArrayTTLRow( + groupingKey: Array[Byte], + expirationMs: Long) + /** * Encapsulates the ttl row information stored in [[CompositeKeyTTLStateImpl]]. * @@ -80,14 +89,21 @@ abstract class SingleKeyTTLStateImpl( stateName: String, store: StateStore, keyExprEnc: ExpressionEncoder[Any], - ttlExpirationMs: Long) - extends TTLState { + ttlExpirationMs: Long, + avroEnc: Option[AvroEncoderSpec]) + extends TTLState with Logging { import org.apache.spark.sql.execution.streaming.StateTTLSchema._ + private val usingAvro: Boolean = avroEnc.isDefined private val ttlColumnFamilyName = "$ttl_" + stateName - private val keySchema = getSingleKeyTTLRowSchema(keyExprEnc.schema) - private val keyTTLRowEncoder = new SingleKeyTTLEncoder(keyExprEnc) + private val keySchema = if (usingAvro) { + getSingleKeyTTLAvroRowSchema + } else { + getSingleKeyTTLRowSchema(keyExprEnc.schema) + } + private val keyAvroType = SchemaConverters.toAvroType(keySchema) + private val keyTTLRowEncoder = new SingleKeyTTLEncoder(keyExprEnc, avroEnc) // empty row used for values private val EMPTY_ROW = @@ -116,22 +132,57 @@ abstract class SingleKeyTTLStateImpl( store.put(encodedTtlKey, EMPTY_ROW, ttlColumnFamilyName) } + def upsertTTLForStateKey( + expirationMs: Long, + groupingKey: Array[Byte]): Unit = { + val encodedTtlKey = keyTTLRowEncoder.encodeTTLRow( + expirationMs, groupingKey) + store.put(encodedTtlKey, Array[Byte](4), ttlColumnFamilyName) + } + /** * Clears any state which has ttl older than [[ttlExpirationMs]]. 
*/ override def clearExpiredState(): Long = { - val iterator = store.iterator(ttlColumnFamilyName) - var numValuesExpired = 0L - - iterator.takeWhile { kv => - val expirationMs = kv.key.getLong(0) - StateTTL.isExpired(expirationMs, ttlExpirationMs) - }.foreach { kv => - val groupingKey = kv.key.getStruct(1, keyExprEnc.schema.length) - numValuesExpired += clearIfExpired(groupingKey) - store.remove(kv.key, ttlColumnFamilyName) + if (usingAvro) { + val iterator = store.byteArrayIter(ttlColumnFamilyName) + var numValuesExpired = 0L + + iterator.takeWhile { kv => + val row = kv.key + val reader = new GenericDatumReader[Any](keyAvroType) + val decoder = DecoderFactory.get().binaryDecoder(row, 0, row.length, null) + val genericData = reader.read(null, decoder) // bytes -> GenericDataRecord + val internalRow = avroEnc.get.keyDeserializer.deserialize( + genericData).orNull.asInstanceOf[InternalRow] // GenericDataRecord -> InternalRow + val expirationMs = internalRow.getLong(0) + StateTTL.isExpired(expirationMs, ttlExpirationMs) + }.foreach { kv => + val row = kv.key + val reader = new GenericDatumReader[Any](keyAvroType) + val decoder = DecoderFactory.get().binaryDecoder(row, 0, row.length, null) + val genericData = reader.read(null, decoder) // bytes -> GenericDataRecord + val internalRow = avroEnc.get.keyDeserializer.deserialize( + genericData).orNull.asInstanceOf[InternalRow] // GenericDataRecord -> InternalRow + val groupingKey = internalRow.getBinary(1) + numValuesExpired += clearIfExpired(groupingKey) + store.remove(kv.key, ttlColumnFamilyName) + } + numValuesExpired + } else { + val iterator = store.iterator(ttlColumnFamilyName) + var numValuesExpired = 0L + + iterator.takeWhile { kv => + val expirationMs = kv.key.getLong(0) + StateTTL.isExpired(expirationMs, ttlExpirationMs) + }.foreach { kv => + val groupingKey = kv.key.getStruct(1, keyExprEnc.schema.length) + numValuesExpired += clearIfExpired(groupingKey) + store.remove(kv.key, ttlColumnFamilyName) + } + numValuesExpired } - numValuesExpired } private[sql] def ttlIndexIterator(): Iterator[SingleKeyTTLRow] = { @@ -174,6 +225,50 @@ abstract class SingleKeyTTLStateImpl( } } + private[sql] def ttlIndexByteIterator(): Iterator[SingleKeyByteArrayTTLRow] = { + val ttlIterator = store.byteArrayIter(ttlColumnFamilyName) + + new Iterator[SingleKeyByteArrayTTLRow] { + override def hasNext: Boolean = ttlIterator.hasNext + + override def next(): SingleKeyByteArrayTTLRow = { + val row = ttlIterator.next().key + val reader = new GenericDatumReader[Any](keyAvroType) + val decoder = DecoderFactory.get().binaryDecoder(row, 0, row.length, null) + val genericData = reader.read(null, decoder) // bytes -> GenericDataRecord + val internalRow = avroEnc.get.keyDeserializer.deserialize( + genericData).orNull.asInstanceOf[InternalRow] // GenericDataRecord -> InternalRow + SingleKeyByteArrayTTLRow( + expirationMs = internalRow.getLong(0), + groupingKey = internalRow.getBinary(1) + ) + } + } + } + private[sql] def getValuesInTTLState(groupingKey: Array[Byte]): Iterator[Long] = { + val ttlIterator = ttlIndexByteIterator() + var nextValue: Option[Long] = None + + new Iterator[Long] { + override def hasNext: Boolean = { + while (nextValue.isEmpty && ttlIterator.hasNext) { + val nextTtlValue = ttlIterator.next() + val valueGroupingKey = nextTtlValue.groupingKey + if (valueGroupingKey sameElements groupingKey) { + nextValue = Some(nextTtlValue.expirationMs) + } + } + nextValue.isDefined + } + + override def next(): Long = { + val result = nextValue.get + nextValue = None 
+ result + } + } + } + /** * Clears the user state associated with this grouping key * if it has expired. This function is called by Spark to perform @@ -190,6 +285,8 @@ abstract class SingleKeyTTLStateImpl( * @return true if the state was cleared, false otherwise. */ def clearIfExpired(groupingKey: UnsafeRow): Long + + def clearIfExpired(groupingKey: Array[Byte]): Long } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala index ac7a83ff65c21..eaa358a6298b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala @@ -45,11 +45,15 @@ class ValueStateImplWithTTL[S]( ttlConfig: TTLConfig, batchTimestampMs: Long, avroEnc: Option[AvroEncoderSpec], // TODO: Add Avro Encoding support for TTL + ttlAvroEnc: Option[AvroEncoderSpec], metrics: Map[String, SQLMetric] = Map.empty) extends SingleKeyTTLStateImpl( - stateName, store, keyExprEnc, batchTimestampMs) with ValueState[S] { + stateName, store, keyExprEnc, batchTimestampMs, ttlAvroEnc) with ValueState[S] { - private val stateTypesEncoder = UnsafeRowTypesEncoder(keyExprEnc, valEncoder, + private val usingAvro: Boolean = avroEnc.isDefined + private val avroTypesEncoder = new AvroTypesEncoder( + keyExprEnc, valEncoder, stateName, hasTtl = true, avroEnc) + private val unsafeRowTypesEncoder = UnsafeRowTypesEncoder(keyExprEnc, valEncoder, stateName, hasTtl = true) private val ttlExpirationMs = StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs) @@ -74,13 +78,38 @@ class ValueStateImplWithTTL[S]( /** Function to return associated value with key if exists and null otherwise */ override def get(): S = { - val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() + if (usingAvro) { + getAvro() + } else { + getUnsafeRow() + } + } + + private def getUnsafeRow(): S = { + val encodedGroupingKey = unsafeRowTypesEncoder.encodeGroupingKey() val retRow = store.get(encodedGroupingKey, stateName) if (retRow != null) { - val resState = stateTypesEncoder.decodeValue(retRow) + val resState = unsafeRowTypesEncoder.decodeValue(retRow) - if (!stateTypesEncoder.isExpired(retRow, batchTimestampMs)) { + if (!unsafeRowTypesEncoder.isExpired(retRow, batchTimestampMs)) { + resState + } else { + null.asInstanceOf[S] + } + } else { + null.asInstanceOf[S] + } + } + + private def getAvro(): S = { + val encodedGroupingKey = avroTypesEncoder.encodeGroupingKey() + val retRow = store.get(encodedGroupingKey, stateName) + + if (retRow != null) { + val resState = avroTypesEncoder.decodeValue(retRow) + + if (!avroTypesEncoder.isExpired(retRow, batchTimestampMs)) { resState } else { null.asInstanceOf[S] @@ -92,17 +121,30 @@ class ValueStateImplWithTTL[S]( /** Function to update and overwrite state associated with given key */ override def update(newState: S): Unit = { - val encodedValue = stateTypesEncoder.encodeValue(newState, ttlExpirationMs) - val serializedGroupingKey = stateTypesEncoder.encodeGroupingKey() - store.put(serializedGroupingKey, - encodedValue, stateName) - TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows") - upsertTTLForStateKey(ttlExpirationMs, serializedGroupingKey) + if (usingAvro) { + val encodedValue = avroTypesEncoder.encodeValue(newState, ttlExpirationMs) + val serializedGroupingKey = avroTypesEncoder.encodeGroupingKey() 
+ store.put(serializedGroupingKey, + encodedValue, stateName) + TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows") + upsertTTLForStateKey(ttlExpirationMs, serializedGroupingKey) + } else { + val encodedValue = unsafeRowTypesEncoder.encodeValue(newState, ttlExpirationMs) + val serializedGroupingKey = unsafeRowTypesEncoder.encodeGroupingKey() + store.put(serializedGroupingKey, + encodedValue, stateName) + TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows") + upsertTTLForStateKey(ttlExpirationMs, serializedGroupingKey) + } } /** Function to remove state for given key */ override def clear(): Unit = { - store.remove(stateTypesEncoder.encodeGroupingKey(), stateName) + if (usingAvro) { + store.remove(avroTypesEncoder.encodeGroupingKey(), stateName) + } else { + store.remove(unsafeRowTypesEncoder.encodeGroupingKey(), stateName) + } TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows") clearTTLState() } @@ -112,7 +154,21 @@ class ValueStateImplWithTTL[S]( var result = 0L if (retRow != null) { - if (stateTypesEncoder.isExpired(retRow, batchTimestampMs)) { + if (unsafeRowTypesEncoder.isExpired(retRow, batchTimestampMs)) { + store.remove(groupingKey, stateName) + TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows") + result = 1L + } + } + result + } + + def clearIfExpired(groupingKey: Array[Byte]): Long = { + val retRow = store.get(groupingKey, stateName) + + var result = 0L + if (retRow != null) { + if (avroTypesEncoder.isExpired(retRow, batchTimestampMs)) { store.remove(groupingKey, stateName) TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows") result = 1L @@ -133,14 +189,26 @@ class ValueStateImplWithTTL[S]( * end of the micro-batch. */ private[sql] def getWithoutEnforcingTTL(): Option[S] = { - val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() - val retRow = store.get(encodedGroupingKey, stateName) + if (usingAvro) { + val encodedGroupingKey = avroTypesEncoder.encodeGroupingKey() + val retRow = store.get(encodedGroupingKey, stateName) - if (retRow != null) { - val resState = stateTypesEncoder.decodeValue(retRow) - Some(resState) + if (retRow != null) { + val resState = avroTypesEncoder.decodeValue(retRow) + Some(resState) + } else { + None + } } else { - None + val encodedGroupingKey = unsafeRowTypesEncoder.encodeGroupingKey() + val retRow = store.get(encodedGroupingKey, stateName) + + if (retRow != null) { + val resState = unsafeRowTypesEncoder.decodeValue(retRow) + Some(resState) + } else { + None + } } } @@ -148,16 +216,30 @@ class ValueStateImplWithTTL[S]( * Read the ttl value associated with the grouping key. 
*/ private[sql] def getTTLValue(): Option[(S, Long)] = { - val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() - val retRow = store.get(encodedGroupingKey, stateName) + if (usingAvro) { + val encodedGroupingKey = avroTypesEncoder.encodeGroupingKey() + val retRow = store.get(encodedGroupingKey, stateName) - // if the returned row is not null, we want to return the value associated with the - // ttlExpiration - if (retRow != null) { - val ttlExpiration = stateTypesEncoder.decodeTtlExpirationMs(retRow) - ttlExpiration.map(expiration => (stateTypesEncoder.decodeValue(retRow), expiration)) + // if the returned row is not null, we want to return the value associated with the + // ttlExpiration + if (retRow != null) { + val ttlExpiration = avroTypesEncoder.decodeTtlExpirationMs(retRow) + ttlExpiration.map(expiration => (avroTypesEncoder.decodeValue(retRow), expiration)) + } else { + None + } } else { - None + val encodedGroupingKey = unsafeRowTypesEncoder.encodeGroupingKey() + val retRow = store.get(encodedGroupingKey, stateName) + + // if the returned row is not null, we want to return the value associated with the + // ttlExpiration + if (retRow != null) { + val ttlExpiration = unsafeRowTypesEncoder.decodeTtlExpirationMs(retRow) + ttlExpiration.map(expiration => (unsafeRowTypesEncoder.decodeValue(retRow), expiration)) + } else { + None + } } } @@ -166,7 +248,11 @@ class ValueStateImplWithTTL[S]( * grouping key. */ private[sql] def getValuesInTTLState(): Iterator[Long] = { - getValuesInTTLState(stateTypesEncoder.encodeGroupingKey()) + if (usingAvro) { + getValuesInTTLState(avroTypesEncoder.encodeGroupingKey()) + } else { + getValuesInTTLState(unsafeRowTypesEncoder.encodeGroupingKey()) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index ba3036bbd652f..eeecd192735c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -1053,14 +1053,19 @@ class RangeKeyScanStateEncoder( // Skip column family prefix if present val startOffset = decodeKeyStartOffset + // Skip the version byte - we don't return this to the user + val dataStartOffset = startOffset + STATE_ENCODING_NUM_VERSION_BYTES + // Read prefix length - val prefixLen = Platform.getInt(keyBytes, startOffset) + val prefixLen = Platform.getInt(keyBytes, dataStartOffset) - // Total length of data we want to return - val totalLen = keyBytes.length - 4 - offsetForColFamilyPrefix + // Total length of data we want to return = total length - + // (version byte + prefix length int + column family prefix) + val totalLen = keyBytes.length - STATE_ENCODING_NUM_VERSION_BYTES - 4 - + offsetForColFamilyPrefix val result = new Array[Byte](totalLen) - Platform.copyMemory(keyBytes, startOffset + 4, + Platform.copyMemory(keyBytes, dataStartOffset + 4, result, Platform.BYTE_ARRAY_OFFSET, totalLen) result diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala index a4fd9547ec5d8..e1c282e611224 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala @@ -40,6 +40,7 @@ case class StateSchemaValidationResult( case class AvroEncoderSpec( keySerializer: AvroSerializer, + keyDeserializer: AvroDeserializer, valueSerializer: AvroSerializer, valueDeserializer: AvroDeserializer, compositeKeySerializer: Option[AvroSerializer] = None, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala index 2ddf69aa49e04..7bbe7b0bbd287 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala @@ -21,7 +21,7 @@ import java.sql.Timestamp import java.time.Duration import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider +import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock @@ -41,14 +41,14 @@ case class OutputEvent( * Test suite base for TransformWithState with TTL support. */ abstract class TransformWithStateTTLTest - extends StreamTest { + extends StreamTest with AlsoTestWithChangelogCheckpointingEnabled { import testImplicits._ def getProcessor(ttlConfig: TTLConfig): StatefulProcessor[String, InputEvent, OutputEvent] def getStateTTLMetricName: String - test("validate state is evicted at ttl expiry") { + testWithAvroEncoding("validate state is evicted at ttl expiry") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { withTempDir { dir => @@ -125,7 +125,7 @@ abstract class TransformWithStateTTLTest } } - test("validate state update updates the expiration timestamp") { + testWithAvroEncoding("validate state update updates the expiration timestamp") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { val inputStream = MemoryStream[InputEvent] @@ -187,7 +187,7 @@ abstract class TransformWithStateTTLTest } } - test("validate state is evicted at ttl expiry for no data batch") { + testWithAvroEncoding("validate state is evicted at ttl expiry for no data batch") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { val inputStream = MemoryStream[InputEvent] @@ -238,7 +238,7 @@ abstract class TransformWithStateTTLTest } } - test("validate only expired keys are removed from the state") { + testWithAvroEncoding("validate only expired keys are removed from the state") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> "1") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala index e7b394db0c3c7..effa2fa789a18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala @@ -184,7 +184,8 @@ class TTLProcessorWithCompositeTypes( } } -class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest { +class 
TransformWithValueStateTTLSuite + extends TransformWithStateTTLTest with AlsoTestWithChangelogCheckpointingEnabled { import testImplicits._ @@ -195,7 +196,7 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest { override def getStateTTLMetricName: String = "numValueStateWithTTLVars" - test("validate multiple value states") { + testWithAvroEncoding("validate multiple value states") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { val ttlKey = "k1" From 25f3db7e36356fbefb8ec83c05d144490431e892 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Thu, 31 Oct 2024 18:43:16 -0700 Subject: [PATCH 3/3] creating encoder just for ttl/timers --- .../StateStoreColumnFamilySchemaUtils.scala | 8 +- .../streaming/StateTypesEncoderUtils.scala | 8 +- .../sql/execution/streaming/TTLState.scala | 13 +- .../streaming/state/RocksDBStateEncoder.scala | 430 +++++------------- .../streaming/state/StateStore.scala | 18 + .../state/RocksDBStateStoreSuite.scala | 95 ++-- 6 files changed, 216 insertions(+), 356 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala index 6bd6c959746f5..c61d28cbaa28f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala @@ -21,8 +21,8 @@ import org.apache.spark.sql.Encoder import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions, AvroSerializer, SchemaConverters} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._ -import org.apache.spark.sql.execution.streaming.state.{AvroEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, RangeKeyScanStateEncoderSpec, StateStoreColFamilySchema} -import org.apache.spark.sql.types.{BinaryType, LongType, NullType, StructField, StructType} +import org.apache.spark.sql.execution.streaming.state.{AvroEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, RangeKeyScanStateEncoderSpec, StateStoreColFamilySchema, TTLRangeKeyScanStateEncoderSpec} +import org.apache.spark.sql.types.{BinaryType, NullType, StructField, StructType} object StateStoreColumnFamilySchemaUtils { @@ -68,14 +68,14 @@ class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) extends Lo stateName: String, keyEncoder: ExpressionEncoder[Any]): StateStoreColFamilySchema = { val ttlKeySchema = new StructType() - .add("expirationMs", LongType) + .add("expirationMs", BinaryType) .add("groupingKey", BinaryType) val ttlValSchema = StructType(Array(StructField("__dummy__", NullType))) StateStoreColFamilySchema( stateName, ttlKeySchema, ttlValSchema, - Some(RangeKeyScanStateEncoderSpec(ttlKeySchema, Seq(0))), + Some(TTLRangeKeyScanStateEncoderSpec(ttlKeySchema, Seq(0))), avroEnc = getAvroSerde(ttlKeySchema, ttlValSchema)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala index d51a5180907e6..e3ebc8654bb92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.streaming import java.io.ByteArrayOutputStream +import java.nio.ByteBuffer import org.apache.avro.generic.{GenericDatumReader, GenericDatumWriter} import org.apache.avro.io.{DecoderFactory, EncoderFactory} @@ -52,7 +53,7 @@ object TransformWithStateKeyValueRowSchemaUtils { def getSingleKeyTTLAvroRowSchema: StructType = new StructType() - .add("expirationMs", LongType) + .add("expirationMs", BinaryType) .add("groupingKey", BinaryType) def getCompositeKeyTTLRowSchema( @@ -481,8 +482,9 @@ class SingleKeyTTLEncoder( } def encodeTTLRow(expirationMs: Long, groupingKey: Array[Byte]): Array[Byte] = { - val internalRow = InternalRow(expirationMs, groupingKey) - val avroData = avroEnc.get.keySerializer.serialize(internalRow) + val expMsBytes = ByteBuffer.allocate(8).putLong(expirationMs).array() + val internalRow = InternalRow(expMsBytes, groupingKey) + val avroData = avroEnc.get.keySerializer.serialize(internalRow) // InternalRow -> Avro Record out.reset() val encoder = EncoderFactory.get().directBinaryEncoder(out, null) val writer = new GenericDatumWriter[Any]( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala index 7b24f72d0db39..1836d472fd1c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.execution.streaming +import java.nio.ByteBuffer import java.time.Duration import org.apache.avro.generic.GenericDatumReader @@ -28,7 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._ -import org.apache.spark.sql.execution.streaming.state.{AvroEncoderSpec, RangeKeyScanStateEncoderSpec, StateStore} +import org.apache.spark.sql.execution.streaming.state.{AvroEncoderSpec, RangeKeyScanStateEncoderSpec, StateStore, TTLRangeKeyScanStateEncoderSpec} import org.apache.spark.sql.types._ object StateTTLSchema { @@ -110,7 +111,8 @@ abstract class SingleKeyTTLStateImpl( UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null)) store.createColFamilyIfAbsent(ttlColumnFamilyName, keySchema, TTL_VALUE_ROW_SCHEMA, - RangeKeyScanStateEncoderSpec(keySchema, Seq(0)), isInternal = true) + if (usingAvro) TTLRangeKeyScanStateEncoderSpec(keySchema, Seq(0)) + else RangeKeyScanStateEncoderSpec(keySchema, Seq(0)), isInternal = true) /** * This function will be called when clear() on State Variables @@ -137,6 +139,7 @@ abstract class SingleKeyTTLStateImpl( groupingKey: Array[Byte]): Unit = { val encodedTtlKey = keyTTLRowEncoder.encodeTTLRow( expirationMs, groupingKey) + logError(s"### encodedTtlKey: ${encodedTtlKey.mkString("Array(", ", ", ")")}") store.put(encodedTtlKey, Array[Byte](4), ttlColumnFamilyName) } @@ -155,7 +158,8 @@ abstract class SingleKeyTTLStateImpl( val genericData = reader.read(null, decoder) // bytes -> GenericDataRecord val internalRow = avroEnc.get.keyDeserializer.deserialize( genericData).orNull.asInstanceOf[InternalRow] // GenericDataRecord -> InternalRow - val expirationMs = internalRow.getLong(0) + val 
expMsBytes = internalRow.getBinary(0) + val expirationMs = ByteBuffer.wrap(expMsBytes).getLong() StateTTL.isExpired(expirationMs, ttlExpirationMs) }.foreach { kv => val row = kv.key @@ -238,8 +242,9 @@ abstract class SingleKeyTTLStateImpl( val genericData = reader.read(null, decoder) // bytes -> GenericDataRecord val internalRow = avroEnc.get.keyDeserializer.deserialize( genericData).orNull.asInstanceOf[InternalRow] // GenericDataRecord -> InternalRow + logError(s"### expirationMs: ${ByteBuffer.wrap(internalRow.getBinary(0)).getLong()}") SingleKeyByteArrayTTLRow( - expirationMs = internalRow.getLong(0), + expirationMs = ByteBuffer.wrap(internalRow.getBinary(0)).getLong(), groupingKey = internalRow.getBinary(1) ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index eeecd192735c8..8447068b880ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -113,6 +113,10 @@ object RocksDBStateEncoder { new RangeKeyScanStateEncoder(keySchema, orderingOrdinals, useColumnFamilies, virtualColFamilyId) + case TTLRangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) => + new TTLRangeKeyScanStateEncoder(keySchema, orderingOrdinals, + useColumnFamilies, virtualColFamilyId) + case _ => throw new IllegalArgumentException(s"Unsupported key state encoder spec: " + s"$keyStateEncoderSpec") @@ -729,346 +733,144 @@ class RangeKeyScanStateEncoder( override def supportPrefixKeyScan: Boolean = true - override def encodePrefixKeyBytes(prefixKeyBytes: Array[Byte]): Array[Byte] = { - // For range scan prefix key, we need to transform the bytes to maintain ordering - // This requires parsing each field according to its type and applying the range-scan encoding + override def encodePrefixKeyBytes(prefixKeyBytes: Array[Byte]): Array[Byte] = + throw new UnsupportedOperationException - var pos = 0 - var totalSize = 0 - - // Calculate total size needed - each field needs marker byte + field size - rangeScanKeyFieldsWithOrdinal.foreach { case (field, _) => - totalSize += 1 // marker byte - val fieldSize = field.dataType match { - case BooleanType | ByteType => 1 - case ShortType => 2 - case IntegerType | FloatType => 4 - case LongType | DoubleType => 8 - case _ => throw new IllegalArgumentException(s"Unsupported type ${field.dataType}") - } - totalSize += fieldSize - } + override def encodeKeyBytes(keyBytes: Array[Byte]): Array[Byte] = + throw new UnsupportedOperationException - // Create byte array for the encoded fields - val encodedFields = new Array[Byte](totalSize) - var encodedPos = 0 - - // Process each field - rangeScanKeyFieldsWithOrdinal.foreach { case (field, _) => - val valueSize = field.dataType match { - case BooleanType | ByteType => 1 - case ShortType => 2 - case IntegerType | FloatType => 4 - case LongType | DoubleType => 8 - } + override def decodeKeyBytes(keyBytes: Array[Byte]): Array[Byte] = + throw new UnsupportedOperationException +} - // Check if we have enough bytes left to read the value - if (pos + valueSize > prefixKeyBytes.length) { - throw new IllegalArgumentException(s"Not enough bytes to read ${field.dataType}") - } - // Create buffer for reading the value in native byte order - val bbuf = ByteBuffer.wrap(prefixKeyBytes, pos, valueSize) - val bbufBE = 
ByteBuffer.allocate(valueSize) - bbufBE.order(ByteOrder.BIG_ENDIAN) - - field.dataType match { - case ByteType => - val value = bbuf.get() - encodedFields(encodedPos) = if (value == 0) { - 0x02.toByte // null marker - } else if (value < 0) { - 0x00.toByte // negative marker - } else { - 0x01.toByte // positive marker - } - encodedPos += 1 - bbufBE.put(value) - bbufBE.flip() - bbufBE.get(encodedFields, encodedPos, valueSize) - - case ShortType => - val value = bbuf.getShort() - encodedFields(encodedPos) = if (value == 0) { - 0x02.toByte - } else if (value < 0) { - 0x00.toByte - } else { - 0x01.toByte - } - encodedPos += 1 - bbufBE.putShort(value) - bbufBE.flip() - bbufBE.get(encodedFields, encodedPos, valueSize) - - case IntegerType => - val value = bbuf.getInt() - encodedFields(encodedPos) = if (value == 0) { - 0x02.toByte - } else if (value < 0) { - 0x00.toByte - } else { - 0x01.toByte - } - encodedPos += 1 - bbufBE.putInt(value) - bbufBE.flip() - bbufBE.get(encodedFields, encodedPos, valueSize) - - case LongType => - val value = bbuf.getLong() - encodedFields(encodedPos) = if (value == 0) { - 0x02.toByte - } else if (value < 0) { - 0x00.toByte - } else { - 0x01.toByte - } - encodedPos += 1 - bbufBE.putLong(value) - bbufBE.flip() - bbufBE.get(encodedFields, encodedPos, valueSize) - - case FloatType => - val value = bbuf.getFloat() - val bits = java.lang.Float.floatToRawIntBits(value) - encodedFields(encodedPos) = if (bits == 0) { - 0x02.toByte - } else if ((bits & floatSignBitMask) != 0) { - // Negative float - need to flip bits for correct ordering - 0x00.toByte - } else { - 0x01.toByte - } - encodedPos += 1 - val floatToWrite = if ((bits & floatSignBitMask) != 0) { - java.lang.Float.intBitsToFloat(bits ^ floatFlipBitMask) - } else value - bbufBE.putFloat(floatToWrite) - bbufBE.flip() - bbufBE.get(encodedFields, encodedPos, valueSize) - - case DoubleType => - val value = bbuf.getDouble() - val bits = java.lang.Double.doubleToRawLongBits(value) - encodedFields(encodedPos) = if (bits == 0) { - 0x02.toByte - } else if ((bits & doubleSignBitMask) != 0) { - // Negative double - need to flip bits for correct ordering - 0x00.toByte - } else { - 0x01.toByte - } - encodedPos += 1 - val doubleToWrite = if ((bits & doubleSignBitMask) != 0) { - java.lang.Double.longBitsToDouble(bits ^ doubleFlipBitMask) - } else value - bbufBE.putDouble(doubleToWrite) - bbufBE.flip() - bbufBE.get(encodedFields, encodedPos, valueSize) - } +// TODO: Change name for TWS +class TTLRangeKeyScanStateEncoder( + keySchema: StructType, + orderingOrdinals: Seq[Int], + useColumnFamilies: Boolean = false, + virtualColFamilyId: Option[Short] = None) + extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) with Logging { - encodedPos += valueSize - pos += valueSize + private val rangeScanKeyFieldsWithOrdinal: Seq[(StructField, Int)] = { + orderingOrdinals.map { ordinal => + val field = keySchema(ordinal) + (field, ordinal) } + } - // Now wrap these encoded fields with version and column family prefix if needed - val (result, startingOffset) = encodeColumnFamilyPrefix(encodedFields.length + 4) - Platform.putInt(result, startingOffset, encodedFields.length) - Platform.copyMemory(encodedFields, Platform.BYTE_ARRAY_OFFSET, - result, startingOffset + 4, encodedFields.length) - - result + private def isFixedSize(dataType: DataType): Boolean = dataType match { + case _: ByteType | _: BooleanType | _: ShortType | _: IntegerType | _: LongType | + _: FloatType | _: DoubleType | _: BinaryType => true + case _ => false } - 
override def encodeKeyBytes(keyBytes: Array[Byte]): Array[Byte] = { - // For full key encoding, we need to: - // 1. Split the input bytes into range scan fields and remaining fields - // 2. Encode the range scan fields using the same logic as encodePrefixKeyBytes - // 3. Append the remaining fields - // 4. Add column family prefix if needed - - var pos = 0 - var totalSize = 0 - - // Calculate size needed for range scan fields - rangeScanKeyFieldsWithOrdinal.foreach { case (field, _) => - totalSize += 1 // marker byte - val fieldSize = field.dataType match { - case BooleanType | ByteType => 1 - case ShortType => 2 - case IntegerType | FloatType => 4 - case LongType | DoubleType => 8 - case _ => throw new IllegalArgumentException(s"Unsupported type ${field.dataType}") + // verify that only fixed sized columns are used for ordering + rangeScanKeyFieldsWithOrdinal.foreach { case (field, ordinal) => + if (!isFixedSize(field.dataType)) { + // NullType is technically fixed size, but not supported for ordering + if (field.dataType == NullType) { + throw StateStoreErrors.nullTypeOrderingColsNotSupported(field.name, ordinal.toString) + } else { + throw StateStoreErrors.variableSizeOrderingColsNotSupported(field.name, ordinal.toString) } - totalSize += fieldSize } + } - // Create array for range scan encoded fields - val encodedRangeFields = new Array[Byte](totalSize) - var encodedPos = 0 - - // Encode range scan fields - rangeScanKeyFieldsWithOrdinal.foreach { case (field, _) => - val valueSize = field.dataType match { - case BooleanType | ByteType => 1 - case ShortType => 2 - case IntegerType | FloatType => 4 - case LongType | DoubleType => 8 - } + private val remainingKeyFieldsWithOrdinal: Seq[(StructField, Int)] = { + 0.to(keySchema.length - 1).diff(orderingOrdinals).map { ordinal => + val field = keySchema(ordinal) + (field, ordinal) + } + } - if (pos + valueSize > keyBytes.length) { - throw new IllegalArgumentException(s"Not enough bytes to read ${field.dataType}") - } + private val rangeScanKeyProjection: UnsafeProjection = { + val refs = rangeScanKeyFieldsWithOrdinal.map(x => + BoundReference(x._2, x._1.dataType, x._1.nullable)) + UnsafeProjection.create(refs) + } - val bbuf = ByteBuffer.wrap(keyBytes, pos, valueSize) - val bbufBE = ByteBuffer.allocate(valueSize) - bbufBE.order(ByteOrder.BIG_ENDIAN) - - field.dataType match { - case ByteType => - val value = bbuf.get() - encodedRangeFields(encodedPos) = if (value == 0) { - 0x02.toByte - } else if (value < 0) { - 0x00.toByte - } else { - 0x01.toByte - } - encodedPos += 1 - bbufBE.put(value) - bbufBE.flip() - bbufBE.get(encodedRangeFields, encodedPos, valueSize) - - case ShortType => - val value = bbuf.getShort() - encodedRangeFields(encodedPos) = if (value == 0) { - 0x02.toByte - } else if (value < 0) { - 0x00.toByte - } else { - 0x01.toByte - } - encodedPos += 1 - bbufBE.putShort(value) - bbufBE.flip() - bbufBE.get(encodedRangeFields, encodedPos, valueSize) - - case IntegerType => - val value = bbuf.getInt() - encodedRangeFields(encodedPos) = if (value == 0) { - 0x02.toByte - } else if (value < 0) { - 0x00.toByte - } else { - 0x01.toByte - } - encodedPos += 1 - bbufBE.putInt(value) - bbufBE.flip() - bbufBE.get(encodedRangeFields, encodedPos, valueSize) - - case LongType => - val value = bbuf.getLong() - encodedRangeFields(encodedPos) = if (value == 0) { - 0x02.toByte - } else if (value < 0) { - 0x00.toByte - } else { - 0x01.toByte - } - encodedPos += 1 - bbufBE.putLong(value) - bbufBE.flip() - bbufBE.get(encodedRangeFields, encodedPos, 
valueSize) - - case FloatType => - val value = bbuf.getFloat() - val bits = java.lang.Float.floatToRawIntBits(value) - encodedRangeFields(encodedPos) = if (bits == 0) { - 0x02.toByte - } else if ((bits & floatSignBitMask) != 0) { - 0x00.toByte - } else { - 0x01.toByte - } - encodedPos += 1 - val floatToWrite = if ((bits & floatSignBitMask) != 0) { - java.lang.Float.intBitsToFloat(bits ^ floatFlipBitMask) - } else value - bbufBE.putFloat(floatToWrite) - bbufBE.flip() - bbufBE.get(encodedRangeFields, encodedPos, valueSize) - - case DoubleType => - val value = bbuf.getDouble() - val bits = java.lang.Double.doubleToRawLongBits(value) - encodedRangeFields(encodedPos) = if (bits == 0) { - 0x02.toByte - } else if ((bits & doubleSignBitMask) != 0) { - 0x00.toByte - } else { - 0x01.toByte - } - encodedPos += 1 - val doubleToWrite = if ((bits & doubleSignBitMask) != 0) { - java.lang.Double.longBitsToDouble(bits ^ doubleFlipBitMask) - } else value - bbufBE.putDouble(doubleToWrite) - bbufBE.flip() - bbufBE.get(encodedRangeFields, encodedPos, valueSize) - } - encodedPos += valueSize - pos += valueSize - } - // Get remaining bytes - val remainingBytes = if (pos < keyBytes.length) { - val remaining = new Array[Byte](keyBytes.length - pos) - Platform.copyMemory(keyBytes, Platform.BYTE_ARRAY_OFFSET + pos, - remaining, Platform.BYTE_ARRAY_OFFSET, remaining.length) - remaining - } else Array.empty[Byte] - - // Create final result with version and column family prefix if needed - val (result, startingOffset) = encodeColumnFamilyPrefix( - 4 + encodedRangeFields.length + remainingBytes.length) - - Platform.putInt(result, startingOffset, encodedRangeFields.length) - Platform.copyMemory(encodedRangeFields, Platform.BYTE_ARRAY_OFFSET, - result, startingOffset + 4, encodedRangeFields.length) - - if (remainingBytes.nonEmpty) { - Platform.copyMemory(remainingBytes, Platform.BYTE_ARRAY_OFFSET, - result, startingOffset + 4 + encodedRangeFields.length, remainingBytes.length) - } + // bit masks used for checking sign or flipping all bits for negative float/double values + private val floatFlipBitMask = 0xFFFFFFFF + private val floatSignBitMask = 0x80000000 + + private val doubleFlipBitMask = 0xFFFFFFFFFFFFFFFFL + private val doubleSignBitMask = 0x8000000000000000L + + // Byte markers used to identify whether the value is null, negative or positive + // To ensure sorted ordering, we use the lowest byte value for negative numbers followed by + // positive numbers and then null values. 
+ private val negativeValMarker: Byte = 0x00.toByte + private val positiveValMarker: Byte = 0x01.toByte + private val nullValMarker: Byte = 0x02.toByte + + override def supportPrefixKeyScan: Boolean = true + + override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = + throw new UnsupportedOperationException + + override def encodeKey(row: UnsafeRow): Array[Byte] = + throw new UnsupportedOperationException + + override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = + throw new UnsupportedOperationException + + private def encodePrefixKeyForRangeScan(value: Array[Byte]): Array[Byte] = { + value + } + + override def encodePrefixKeyBytes(prefixKeyBytes: Array[Byte]): Array[Byte] = { + // For range scan prefix key, we need to transform the bytes to maintain ordering + // This requires parsing each field according to its type and applying the range-scan encoding + + // Now wrap these encoded fields with version and column family prefix if needed + val (result, startingOffset) = encodeColumnFamilyPrefix(prefixKeyBytes.length + 4) + Platform.putInt(result, startingOffset, prefixKeyBytes.length) + Platform.copyMemory(prefixKeyBytes, Platform.BYTE_ARRAY_OFFSET, + result, startingOffset + 4, prefixKeyBytes.length) result } - override def decodeKeyBytes(keyBytes: Array[Byte]): Array[Byte] = { - if (keyBytes == null) return null + override def encodeKeyBytes(keyBytes: Array[Byte]): Array[Byte] = { + // Now wrap these encoded fields with version and column family prefix if needed + val (result, startingOffset) = encodeColumnFamilyPrefix(keyBytes.length + 4) + Platform.putInt(result, startingOffset, keyBytes.length) + Platform.copyMemory(keyBytes, Platform.BYTE_ARRAY_OFFSET, + result, startingOffset + 4, keyBytes.length) - // Skip column family prefix if present - val startOffset = decodeKeyStartOffset + result + } - // Skip the version byte - we don't return this to the user - val dataStartOffset = startOffset + STATE_ENCODING_NUM_VERSION_BYTES + override def decodeKeyBytes(keyBytes: Array[Byte]): Array[Byte] = { + if (keyBytes == null) { + null + } else { + // Get the length of the actual key data from the first 4 bytes after the column family prefix + val prefixKeyEncodedLen = Platform.getInt(keyBytes, decodeKeyStartOffset) - // Read prefix length - val prefixLen = Platform.getInt(keyBytes, dataStartOffset) + // Calculate the total length of data to extract + // This is the encoded length minus: + // - 4 bytes for the length prefix + // - column family prefix bytes (if enabled) + val dataLength = keyBytes.length - 4 - offsetForColFamilyPrefix - // Total length of data we want to return = total length - - // (version byte + prefix length int + column family prefix) - val totalLen = keyBytes.length - STATE_ENCODING_NUM_VERSION_BYTES - 4 - - offsetForColFamilyPrefix + // Create a new byte array to hold the decoded data + val result = new Array[Byte](dataLength) - val result = new Array[Byte](totalLen) - Platform.copyMemory(keyBytes, dataStartOffset + 4, - result, Platform.BYTE_ARRAY_OFFSET, totalLen) + // Copy the data bytes after the length prefix and column family prefix + Platform.copyMemory( + keyBytes, + decodeKeyStartOffset + 4, // Skip length prefix + result, + Platform.BYTE_ARRAY_OFFSET, + dataLength + ) - result + result + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index f0b6fdd41ba18..becbbb2b07d70 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -395,6 +395,10 @@ object KeyStateEncoderSpec { val orderingOrdinals = m("orderingOrdinals"). asInstanceOf[List[_]].map(_.asInstanceOf[BigInt].toInt) RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) + case "TTLRangeKeyScanStateEncoderSpec" => + val orderingOrdinals = m("orderingOrdinals"). + asInstanceOf[List[_]].map(_.asInstanceOf[BigInt].toInt) + TTLRangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) case "PrefixKeyScanStateEncoderSpec" => val numColsPrefixKey = m("numColsPrefixKey").asInstanceOf[BigInt].toInt PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) @@ -435,6 +439,20 @@ case class RangeKeyScanStateEncoderSpec( } } +/** Encodes rows so that they can be range-scanned based on orderingOrdinals */ +case class TTLRangeKeyScanStateEncoderSpec( + keySchema: StructType, + orderingOrdinals: Seq[Int]) extends KeyStateEncoderSpec { + if (orderingOrdinals.isEmpty || orderingOrdinals.length > keySchema.length) { + throw StateStoreErrors.incorrectNumOrderingColsForRangeScan(orderingOrdinals.length.toString) + } + + override def jsonValue: JValue = { + ("keyStateEncoderType" -> JString("TTLRangeKeyScanStateEncoderSpec")) ~ + ("orderingOrdinals" -> orderingOrdinals.map(JInt(_))) + } +} + /** * Trait representing a provider that provide [[StateStore]] instances representing * versions of state data. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index 318349a9a434c..222878302bc3d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -17,11 +17,15 @@ package org.apache.spark.sql.execution.streaming.state +import java.io.ByteArrayOutputStream +import java.nio.ByteBuffer import java.util.UUID import scala.collection.immutable import scala.util.Random +import org.apache.avro.generic.{GenericDatumReader, GenericDatumWriter} +import org.apache.avro.io.{DecoderFactory, EncoderFactory} import org.apache.hadoop.conf.Configuration import org.scalatest.BeforeAndAfter @@ -29,9 +33,12 @@ import org.apache.spark.{SparkConf, SparkUnsupportedOperationException} import org.apache.spark.io.CompressionCodec import org.apache.spark.sql.LocalSparkSession.withSparkSession import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions, AvroSerializer, SchemaConverters} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo +import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils.getSingleKeyTTLAvroRowSchema import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ @@ -821,45 +828,71 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } - testWithColumnFamilies("rocksdb range scan - ordering cols and key schema cols are same", - TestWithBothChangelogCheckpointingEnabledAndDisabled) { 
colFamiliesEnabled => - // use the same schema as value schema for single col key schema - tryWithProviderResource(newStoreProvider(valueSchema, - RangeKeyScanStateEncoderSpec(valueSchema, Seq(0)), colFamiliesEnabled)) { provider => - val store = provider.getStore(0) - val cfName = if (colFamiliesEnabled) "testColFamily" else "default" - if (colFamiliesEnabled) { - store.createColFamilyIfAbsent(cfName, - valueSchema, valueSchema, - RangeKeyScanStateEncoderSpec(valueSchema, Seq(0))) - } + private val ttlKeyAvroType = SchemaConverters.toAvroType( + getSingleKeyTTLAvroRowSchema) + + def encodeTTLRow(expirationMs: Long): Array[Byte] = { + val groupingKey = new Array[Byte](8) + val out = new ByteArrayOutputStream + val expMsBytes = ByteBuffer.allocate(8).putLong(expirationMs).array() + val internalRow = InternalRow(expMsBytes, groupingKey) + val keySer = new AvroSerializer(getSingleKeyTTLAvroRowSchema, ttlKeyAvroType, nullable = false) + val avroData = keySer.serialize(internalRow) // InternalRow -> Avro Record + out.reset() + val encoder = EncoderFactory.get().directBinaryEncoder(out, null) + val writer = new GenericDatumWriter[Any]( + ttlKeyAvroType) // Defining Avro writer for this struct type + + writer.write(avroData, encoder) // GenericDataRecord -> bytes + encoder.flush() + out.toByteArray + } - val timerTimestamps = Seq(931, 8000, 452300, 4200, - -3545, -343, 133, -90, -8014490, -79247, - 90, 1, 2, 8, 3, 35, 6, 9, 5, -233) + def decodeTTLRow(row: Array[Byte]): Long = { + val avroOptions = AvroOptions(Map.empty) + val keyDe = new AvroDeserializer(ttlKeyAvroType, getSingleKeyTTLAvroRowSchema, + avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType, + avroOptions.stableIdPrefixForUnionType, avroOptions.recursiveFieldMaxDepth) + val reader = new GenericDatumReader[Any](ttlKeyAvroType) + val decoder = DecoderFactory.get().binaryDecoder(row, 0, row.length, null) + val genericData = reader.read(null, decoder) // bytes -> GenericDataRecord + val internalRow = keyDe.deserialize( + genericData).orNull.asInstanceOf[InternalRow] // GenericDataRecord -> InternalRow + val expMsBytes = internalRow.getBinary(0) + ByteBuffer.wrap(expMsBytes).getLong() + } + + testWithAvroEncoding("Avro") { + val schema = getSingleKeyTTLAvroRowSchema + tryWithProviderResource(newStoreProvider(schema, + TTLRangeKeyScanStateEncoderSpec(schema, Seq(0)), useColumnFamilies = true)) { provider => + val store = provider.getStore(0) + val cfName = "testColFamily" + store.createColFamilyIfAbsent(cfName, + schema, schema, + TTLRangeKeyScanStateEncoderSpec(schema, Seq(0))) + val timerTimestamps = Seq( + 1698765732000L, // Middle timestamp + 1698764432000L, // Earlier (20 min before) + 1698766932000L, // Latest (20 min after) + 1698765432000L, // Middle-ish + 1698764932000L, // Earlier but not earliest + 1698766432000L, // Later but not latest + 1698765532000L, // Another middle one + 1698764532000L // Another early one + ) timerTimestamps.foreach { ts => // non-timestamp col is of variable size - val keyRow = dataToValueRow(ts) - val valueRow = dataToValueRow(1) + val keyRow = encodeTTLRow(ts) + val valueRow = encodeTTLRow(1) store.put(keyRow, valueRow, cfName) - assert(valueRowToData(store.get(keyRow, cfName)) === 1) + assert(decodeTTLRow(store.get(keyRow, cfName)) === 1) } - - val result = store.iterator(cfName).map { kv => - valueRowToData(kv.key) + val result = store.byteArrayIter(cfName).map { kv => + decodeTTLRow(kv.key) }.toSeq assert(result === timerTimestamps.sorted) - - // also check for prefix scan - 
timerTimestamps.foreach { ts => - val prefix = dataToValueRow(ts) - val result = store.prefixScan(prefix, cfName).map { kv => - assert(valueRowToData(kv.value) === 1) - valueRowToData(kv.key) - }.toSeq - assert(result.size === 1) - } } }
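
Note on the key-ordering assumption behind the test above: the range-scan check only holds because expirationMs is written into the Avro record as 8 raw big-endian bytes (the ByteBuffer.allocate(8).putLong(...) pattern in encodeTTLRow), and TTLRangeKeyScanStateEncoder simply length-prefixes the already-serialized key bytes rather than re-encoding each ordering column. For non-negative longs, fixed-width big-endian bytes compare the same way under unsigned lexicographic byte comparison as the underlying numbers do, which is what the assert(result === timerTimestamps.sorted) check exercises. The sketch below is illustrative only and not part of the patch; the object and method names are invented for the example, and it assumes plain bytewise key comparison on the store side (RocksDB's default bytewise comparator behaves this way).

import java.nio.ByteBuffer

object BigEndianTtlKeyOrderingSketch {
  // Same pattern as encodeTTLRow: ByteBuffer defaults to big-endian byte order,
  // so the most significant byte of the timestamp comes first.
  def encodeExpirationMs(ts: Long): Array[Byte] =
    ByteBuffer.allocate(8).putLong(ts).array()

  // Unsigned lexicographic comparison, mirroring a bytewise key comparator.
  def unsignedCompare(a: Array[Byte], b: Array[Byte]): Int =
    a.zip(b)
      .collectFirst { case (x, y) if x != y => Integer.compare(x & 0xFF, y & 0xFF) }
      .getOrElse(Integer.compare(a.length, b.length))

  def main(args: Array[String]): Unit = {
    // Timestamps taken from the RocksDBStateStoreSuite test in this patch.
    val timestamps = Seq(
      1698765732000L, 1698764432000L, 1698766932000L, 1698765432000L,
      1698764932000L, 1698766432000L, 1698765532000L, 1698764532000L)
    val byByteOrder = timestamps.sortWith { (x, y) =>
      unsignedCompare(encodeExpirationMs(x), encodeExpirationMs(y)) < 0
    }
    // For non-negative longs, byte order and numeric order agree.
    assert(byByteOrder == timestamps.sorted)
  }
}

A LongType field would not give this guarantee once the row is Avro-serialized, since Avro encodes longs with a variable-length zig-zag encoding that does not preserve numeric order under bytewise comparison; storing the timestamp as a BinaryType column of pre-ordered bytes sidesteps that, at the cost of the ByteBuffer round trip visible in the TTLState.scala hunk above.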