@@ -895,6 +895,40 @@ class AvroStateEncoder(
out.toByteArray
}

/**
* Prepends a version byte to the beginning of a byte array.
* The version byte allows the state encoding format to evolve while remaining
* backward compatible.
*
* @param bytesToEncode The original byte array to prepend the version byte to
* @return A new byte array with the version byte prepended at the beginning
*/
private[sql] def prependVersionByte(bytesToEncode: Array[Byte]): Array[Byte] = {
val encodedBytes = new Array[Byte](bytesToEncode.length + STATE_ENCODING_NUM_VERSION_BYTES)
Platform.putByte(encodedBytes, Platform.BYTE_ARRAY_OFFSET, STATE_ENCODING_VERSION)
Platform.copyMemory(
bytesToEncode, Platform.BYTE_ARRAY_OFFSET,
encodedBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
bytesToEncode.length)
encodedBytes
}

/**
* Removes the version byte from the beginning of a byte array.
* This is used when decoding state data to recover the original encoded payload.
*
* @param bytes The byte array containing the version byte at the start
* @return A new byte array with the version byte removed
*/
private[sql] def removeVersionByte(bytes: Array[Byte]): Array[Byte] = {
val resultBytes = new Array[Byte](bytes.length - STATE_ENCODING_NUM_VERSION_BYTES)
Platform.copyMemory(
bytes, STATE_ENCODING_NUM_VERSION_BYTES + Platform.BYTE_ARRAY_OFFSET,
resultBytes, Platform.BYTE_ARRAY_OFFSET, resultBytes.length
)
resultBytes
}
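// Illustrative sketch, not part of this change: how the two helpers above are expected
// to compose. The concrete payload bytes below are made up for the example.
//
//   val payload = Array[Byte](10, 20, 30)
//   val versioned = prependVersionByte(payload)
//   // The first byte carries the format version; the payload follows unchanged.
//   assert(versioned(0) == STATE_ENCODING_VERSION)
//   assert(versioned.length == payload.length + STATE_ENCODING_NUM_VERSION_BYTES)
//   // removeVersionByte is the inverse, so a round trip returns the original bytes.
//   assert(removeVersionByte(versioned).sameElements(payload))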

/**
* This method takes a byte array written using Avro encoding, and
* deserializes it to an UnsafeRow using the Avro deserializer
@@ -956,7 +990,7 @@ class AvroStateEncoder(
private val out = new ByteArrayOutputStream

override def encodeKey(row: UnsafeRow): Array[Byte] = {
keyStateEncoderSpec match {
val keyBytes = keyStateEncoderSpec match {
case NoPrefixKeyStateEncoderSpec(_) =>
val avroRow =
encodeUnsafeRowToAvro(row, avroEncoder.keySerializer, keyAvroType, out)
@@ -967,6 +1001,7 @@
encodeUnsafeRowToAvro(row, avroEncoder.keySerializer, prefixKeyAvroType, out)
case _ => throw unsupportedOperationForKeyStateEncoder("encodeKey")
}
prependVersionByte(keyBytes)
}

override def encodeRemainingKey(row: UnsafeRow): Array[Byte] = {
@@ -978,8 +1013,8 @@
case _ => throw unsupportedOperationForKeyStateEncoder("encodeRemainingKey")
}
// prepend stateSchemaId to the remaining key portion
encodeWithStateSchemaId(
StateSchemaIdRow(currentKeySchemaId, avroRow))
prependVersionByte(encodeWithStateSchemaId(
StateSchemaIdRow(currentKeySchemaId, avroRow)))
}

/**
@@ -1118,16 +1153,18 @@ class AvroStateEncoder(
val encoder = EncoderFactory.get().binaryEncoder(out, null)
writer.write(record, encoder)
encoder.flush()
out.toByteArray
prependVersionByte(out.toByteArray)
}

override def encodeValue(row: UnsafeRow): Array[Byte] = {
val avroRow = encodeUnsafeRowToAvro(row, avroEncoder.valueSerializer, valueAvroType, out)
// prepend stateSchemaId to the Avro-encoded value portion
encodeWithStateSchemaId(StateSchemaIdRow(currentValSchemaId, avroRow))
prependVersionByte(
encodeWithStateSchemaId(StateSchemaIdRow(currentValSchemaId, avroRow)))
}
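// Illustrative sketch, not part of this change: after this patch an encoded value is
// laid out as [ version byte | state schema id | Avro-encoded row payload ], where the
// schema id width is whatever encodeWithStateSchemaId writes (not shown in this diff).
// Decoding therefore strips the version byte before reading the schema id, e.g.:
//
//   val schemaIdRow = decodeStateSchemaIdRow(removeVersionByte(encodeValue(row)))
//   // schemaIdRow.schemaId == currentValSchemaId, and schemaIdRow carries the Avro bytes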

override def decodeKey(bytes: Array[Byte]): UnsafeRow = {
override def decodeKey(rowBytes: Array[Byte]): UnsafeRow = {
val bytes = removeVersionByte(rowBytes)
keyStateEncoderSpec match {
case NoPrefixKeyStateEncoderSpec(_) =>
val schemaIdRow = decodeStateSchemaIdRow(bytes)
@@ -1141,7 +1178,8 @@
}


override def decodeRemainingKey(bytes: Array[Byte]): UnsafeRow = {
override def decodeRemainingKey(rowBytes: Array[Byte]): UnsafeRow = {
val bytes = removeVersionByte(rowBytes)
val schemaIdRow = decodeStateSchemaIdRow(bytes)
keyStateEncoderSpec match {
case PrefixKeyScanStateEncoderSpec(_, _) =>
@@ -1174,7 +1212,8 @@
* @throws UnsupportedOperationException if a field's data type is not supported for range
* scan decoding
*/
override def decodePrefixKeyForRangeScan(bytes: Array[Byte]): UnsafeRow = {
override def decodePrefixKeyForRangeScan(rowBytes: Array[Byte]): UnsafeRow = {
val bytes = removeVersionByte(rowBytes)
val reader = new GenericDatumReader[GenericRecord](rangeScanAvroType)
val decoder = DecoderFactory.get().binaryDecoder(bytes, 0, bytes.length, null)
val record = reader.read(null, decoder)
@@ -1257,7 +1296,8 @@ class AvroStateEncoder(
rowWriter.getRow()
}

override def decodeValue(bytes: Array[Byte]): UnsafeRow = {
override def decodeValue(rowBytes: Array[Byte]): UnsafeRow = {
val bytes = removeVersionByte(rowBytes)
val schemaIdRow = decodeStateSchemaIdRow(bytes)
val writerSchema = getStateSchemaProvider.getSchemaMetadataValue(
StateSchemaMetadataKey(
@@ -1648,45 +1688,11 @@ class NoPrefixKeyStateEncoder(
extends RocksDBKeyStateEncoder with Logging {

override def encodeKey(row: UnsafeRow): Array[Byte] = {
if (!useColumnFamilies) {
dataEncoder.encodeKey(row)
} else {
// First encode the row with the data encoder
val rowBytes = dataEncoder.encodeKey(row)

// Create data array with version byte
val dataWithVersion = new Array[Byte](STATE_ENCODING_NUM_VERSION_BYTES + rowBytes.length)
Platform.putByte(dataWithVersion, Platform.BYTE_ARRAY_OFFSET, STATE_ENCODING_VERSION)
Platform.copyMemory(
rowBytes, Platform.BYTE_ARRAY_OFFSET,
dataWithVersion, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
rowBytes.length
)

dataWithVersion
}
dataEncoder.encodeKey(row)
}

override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
if (!useColumnFamilies) {
dataEncoder.decodeKey(keyBytes)
} else if (keyBytes == null) {
null
} else {
val dataWithVersion = keyBytes

// Skip version byte to get to actual data
val dataLength = dataWithVersion.length - STATE_ENCODING_NUM_VERSION_BYTES

// Extract data bytes and decode using data encoder
val dataBytes = new Array[Byte](dataLength)
Platform.copyMemory(
dataWithVersion, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
dataBytes, Platform.BYTE_ARRAY_OFFSET,
dataLength
)
dataEncoder.decodeKey(dataBytes)
}
dataEncoder.decodeKey(keyBytes)
}

override def supportPrefixKeyScan: Boolean = false
@@ -77,6 +77,20 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
// Verify the version encoded in first byte of the key and value byte arrays
assert(Platform.getByte(kv.key, Platform.BYTE_ARRAY_OFFSET) === STATE_ENCODING_VERSION)
assert(Platform.getByte(kv.value, Platform.BYTE_ARRAY_OFFSET) === STATE_ENCODING_VERSION)

// Verify that the actual key-value pair (kv) matches the expected byte patterns exactly
// (via sameElements), which ensures the serialization format remains consistent and
// backward compatible. This is particularly important for state storage, where the
// format needs to be stable across Spark versions.
val (expectedKey, expectedValue) = if (conf.stateStoreEncodingFormat == "avro") {
(Array(0, 0, 0, 2, 2, 97, 2, 0), Array(0, 0, 0, 2, 2))
} else {
(Array(0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 24, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 97, 0, 0, 0, 0, 0, 0, 0),
Array(0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0))
}
assert(kv.key.sameElements(expectedKey))
assert(kv.value.sameElements(expectedValue))
}
}

@@ -433,7 +433,8 @@ class RocksDBStateEncoderSuite extends SparkFunSuite {
val encodedRemainingKey = encoder.encodeRemainingKey(remainingKeyRow)

// Verify schema ID in remaining key bytes
val decodedSchemaIdRow = encoder.decodeStateSchemaIdRow(encodedRemainingKey)
val decodedSchemaIdRow = encoder.decodeStateSchemaIdRow(
encoder.removeVersionByte(encodedRemainingKey))
assert(decodedSchemaIdRow.schemaId === 18,
"Schema ID not preserved in prefix scan remaining key encoding")
}
@@ -462,7 +463,8 @@ class RocksDBStateEncoderSuite extends SparkFunSuite {
val encodedRemainingKey = encoder.encodeRemainingKey(remainingKeyRow)

// Verify schema ID in remaining key bytes
val decodedSchemaIdRow = encoder.decodeStateSchemaIdRow(encodedRemainingKey)
val decodedSchemaIdRow = encoder.decodeStateSchemaIdRow(
encoder.removeVersionByte(encodedRemainingKey))
assert(decodedSchemaIdRow.schemaId === 24,
"Schema ID not preserved in range scan remaining key encoding")

@@ -565,7 +567,8 @@ class RocksDBStateEncoderSuite extends SparkFunSuite {
val encodedValue = valueEncoder.encodeValue(value)

// Verify schema ID was included and preserved
val decodedSchemaIdRow = avroEncoder.decodeStateSchemaIdRow(encodedValue)
val decodedSchemaIdRow = avroEncoder.decodeStateSchemaIdRow(
avroEncoder.removeVersionByte(encodedValue))
assert(decodedSchemaIdRow.schemaId === 42,
"Schema ID not preserved in single value encoding")
