From a3ec8453b66d46c14245f95b707d4bbe8a1c12cb Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 12 Feb 2025 15:53:36 -0800 Subject: [PATCH 01/16] doc --- ...ructured-streaming-transform-with-state.md | 150 ++++++++++++++++++ 1 file changed, 150 insertions(+) create mode 100644 docs/streaming/structured-streaming-transform-with-state.md diff --git a/docs/streaming/structured-streaming-transform-with-state.md b/docs/streaming/structured-streaming-transform-with-state.md new file mode 100644 index 000000000000..bda01a529a3d --- /dev/null +++ b/docs/streaming/structured-streaming-transform-with-state.md @@ -0,0 +1,150 @@ +--- +layout: global +displayTitle: TransformWithState Programming Guide +title: TransformWithState Programming Guide +license: | + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +--- + +# TransformWithState Programming Guide + +The TransformWithState API enables stateful stream processing in Structured Streaming, allowing you to maintain and update state for each unique key in your streaming data. + +## Overview + +TransformWithState provides functionality to: +* Maintain state variables for each unique grouping key +* Process records that share the same key together +* Schedule timers for future processing +* Control state expiration via TTL (Time To Live) +* Evolve schemas safely over time +* Initialize state from existing sources + +## Key Concepts + +### State Variables + +State variables store data associated with each unique key. There are three types of state variables, each with specific operations: + +#### ValueState +Provides operations for single value state: +* `exists()`: Check if state exists +* `get()`: Get the state value if it exists +* `update(newState)`: Update the value +* `clear()`: Remove this state + +#### ListState +Provides operations for list state: +* `exists()`: Check if state exists +* `get()`: Get iterator over all values +* `put(newState)`: Set the entire list +* `appendValue(newState)`: Append single value +* `appendList(newState)`: Append array of values +* `clear()`: Remove this state + +#### MapState +Provides operations for map state: +* `exists()`: Check if state exists +* `getValue(key)`: Get value for map key +* `containsKey(key)`: Check if key exists +* `updateValue(key, value)`: Update/add key-value pair +* `iterator()`: Get iterator over all key-value pairs +* `keys()`: Get iterator over all keys +* `values()`: Get iterator over all values +* `removeKey(key)`: Remove specific key +* `clear()`: Remove all state + +### Configuration + +When initializing state variables in `init()`, you can configure: + + + + + + + + + + + + + + + +
OptionDescription
stateNameUnique identifier for the state variable
encoderSQL encoder for the state type
ttlConfigTTL configuration for state expiration. Use TTLConfig(duration) to set expiration time or TTLConfig.NONE to disable TTL
+ +State variables must be: +1. Declared as instance variables +2. Initialized in `init()` method using `getHandle.getValueState()`, `getListState()`, or `getMapState()` +3. Checked for existence using `.exists()` before first read (no default values) + +### Timers + +Timers allow scheduling callbacks for future execution in either processing-time or event-time. + +#### Timer Operations +Available through `getHandle` in StatefulProcessor: +* `registerTimer(expiryTimestampMs)`: Schedule a timer for the current key +* `deleteTimer(expiryTimestampMs)`: Delete a specific timer +* `listTimers()`: Get iterator of all registered timers for current key + +#### Timer Values +The `TimerValues` interface provides: +* `getCurrentProcessingTimeInMs()`: Current processing time as epoch milliseconds + * Constant throughout a streaming query trigger + * Use for processing-time timers +* `getCurrentWatermarkInMs()`: Current event time watermark + * Only available when watermark is set + * Returns 0 in first micro-batch + * Use for event-time timers + +#### Timer Handling +When a timer expires, `handleExpiredTimer` is called with: +* `key`: The grouping key for this timer +* `timerValues`: Current timer values +* `expiredTimerInfo`: Information about the expired timer + * `getExpiryTimeInMs()`: Get the timer's expiry time + +Key timer characteristics: +* Uniquely identified by timestamp within each key +* Cannot be created/deleted in `init()` +* Will fire even without data for that key +* Fire at or after scheduled time, never before +* Use `timerValues` timestamps for fault tolerance +* Can be created/deleted while processing data or handling another timer + +### Initial State +When providing initial state via `StatefulProcessorWithInitialState`: +* Available in first micro-batch only +* Must match grouping key schema of input rows +* Processed via `handleInitialState()` method +* Useful for migrating state from existing queries + +### Schema Evolution +Supported schema changes between query versions: +* Adding fields (new fields get null defaults) +* Removing fields +* Upcasting fields (e.g., Int to Long) +* Reordering fields + +Schema evolution rules: +* Changes validated at query start +* Fields matched by name, not position +* Must be compatible with all previous versions +* Null defaults for all fields + +[Example usage...] 
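For example, a minimal processor that keeps a running count per key might look like the sketch below. It is an illustration assuming the API surface described above (`StatefulProcessor`, `getHandle.getValueState`, `handleInputRows`, and a `transformWithState` operator on a grouped Dataset); exact package locations and method signatures may differ between Spark versions, so treat it as a sketch rather than a drop-in implementation.

```scala
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor, TTLConfig, TimeMode, TimerValues, ValueState}

// Keeps one Long of state per grouping key: the number of rows seen so far.
class CountingProcessor extends StatefulProcessor[String, String, (String, Long)] {

  @transient private var countState: ValueState[Long] = _

  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
    // "count" is the stateName; TTLConfig.NONE disables state expiration.
    countState = getHandle.getValueState[Long]("count", Encoders.scalaLong, TTLConfig.NONE)
  }

  override def handleInputRows(
      key: String,
      inputRows: Iterator[String],
      timerValues: TimerValues): Iterator[(String, Long)] = {
    // State has no default value, so check exists() before the first read.
    val previous = if (countState.exists()) countState.get() else 0L
    val updated = previous + inputRows.size
    countState.update(updated)
    Iterator.single((key, updated))
  }
}

// Hypothetical wiring into a streaming query (`inputStream` is a placeholder Dataset[String]):
// val counts = inputStream
//   .groupByKey(identity[String])
//   .transformWithState(new CountingProcessor(), TimeMode.None(), OutputMode.Update())
```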
\ No newline at end of file From ce6b53345c8ed80c758e0436dcfc5cbdd4457c94 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 12 Feb 2025 16:48:06 -0800 Subject: [PATCH 02/16] adding avro encoding --- ...ructured-streaming-transform-with-state.md | 73 +++++++++++++++---- 1 file changed, 59 insertions(+), 14 deletions(-) diff --git a/docs/streaming/structured-streaming-transform-with-state.md b/docs/streaming/structured-streaming-transform-with-state.md index bda01a529a3d..0258494afdbf 100644 --- a/docs/streaming/structured-streaming-transform-with-state.md +++ b/docs/streaming/structured-streaming-transform-with-state.md @@ -134,17 +134,62 @@ When providing initial state via `StatefulProcessorWithInitialState`: * Processed via `handleInitialState()` method * Useful for migrating state from existing queries -### Schema Evolution -Supported schema changes between query versions: -* Adding fields (new fields get null defaults) -* Removing fields -* Upcasting fields (e.g., Int to Long) -* Reordering fields - -Schema evolution rules: -* Changes validated at query start -* Fields matched by name, not position -* Must be compatible with all previous versions -* Null defaults for all fields - -[Example usage...] \ No newline at end of file +### State Store Encoding + +The TransformWithState API supports two encoding formats for storing state information: UnsafeRow and Avro. This can be configured using the `spark.sql.streaming.stateStore.encodingFormat` configuration option. + +#### Encoding Configuration + +Set the encoding format using: + +```scala +spark.conf.set("spark.sql.streaming.stateStore.encodingFormat", "avro") // or "unsaferow" +``` + +The default value is "unsaferow". + +#### Encoding Options + +##### UnsafeRow Encoding +- Default encoding format +- Optimized for performance with minimal serialization overhead +- Best choice when schema evolution is not required +- More memory-efficient for simple data types + +##### Avro Encoding +- Added in Spark 4.0.0 +- Provides robust schema evolution capabilities +- Supports all schema evolution operations: + - Adding fields + - Removing fields + - Changing field orders + - Type promotion (e.g., Int to Long) +- Better suited for complex data types and nested structures +- Recommended when schema flexibility is important + +#### Best Practices + +1. Use UnsafeRow encoding when: + - Performance is the top priority + - Schema is stable and unlikely to change + - Working primarily with simple data types + - Memory efficiency is crucial + +2. Use Avro encoding when: + - Schema evolution is needed + - Working with complex or nested data structures + - Long-term maintainability is prioritized over raw performance + - Planning to modify state schemas over time + +3. Consider the tradeoffs: + - UnsafeRow provides better performance but limited schema evolution + - Avro offers more flexibility but with some performance overhead + +#### Schema Evolution with Avro + +When using Avro encoding, you can evolve your schemas following standard Avro compatibility rules: + +1. Adding Fields: New fields must have default values +2. Removing Fields: Existing data will retain removed fields +3. Type Promotion: Must follow valid Avro type promotion rules +4. 
Reordering: Fields can be reordered freely From 3f67baea41f8ef602fdb6f039b6ee2e0f5e094db Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Tue, 18 Feb 2025 11:55:58 -0800 Subject: [PATCH 03/16] init --- .../streaming/state/RocksDBStateEncoder.scala | 49 +++++++++---------- .../state/RocksDBStateStoreProvider.scala | 2 +- 2 files changed, 23 insertions(+), 28 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index c7b324ec32e6..d406f1a9ab89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -1648,23 +1648,7 @@ class NoPrefixKeyStateEncoder( extends RocksDBKeyStateEncoder with Logging { override def encodeKey(row: UnsafeRow): Array[Byte] = { - if (!useColumnFamilies) { - dataEncoder.encodeKey(row) - } else { - // First encode the row with the data encoder - val rowBytes = dataEncoder.encodeKey(row) - - // Create data array with version byte - val dataWithVersion = new Array[Byte](STATE_ENCODING_NUM_VERSION_BYTES + rowBytes.length) - Platform.putByte(dataWithVersion, Platform.BYTE_ARRAY_OFFSET, STATE_ENCODING_VERSION) - Platform.copyMemory( - rowBytes, Platform.BYTE_ARRAY_OFFSET, - dataWithVersion, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES, - rowBytes.length - ) - - dataWithVersion - } + dataEncoder.encodeKey(row) } override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = { @@ -1675,17 +1659,28 @@ class NoPrefixKeyStateEncoder( } else { val dataWithVersion = keyBytes - // Skip version byte to get to actual data - val dataLength = dataWithVersion.length - STATE_ENCODING_NUM_VERSION_BYTES + val version = Platform.getByte(dataWithVersion, Platform.BYTE_ARRAY_OFFSET) + // For version 0, we were writing an extra version byte in the key row. + // We want to skip over this byte, as it is not necessary, and dealt with in + // dataEncoder.decodeKey. 
+ // This is fixed for subsequent versions + val rowBytes = if (version == 0 && !dataEncoder.supportsSchemaEvolution) { + // Skip version byte to get to actual data + val dataLength = dataWithVersion.length - STATE_ENCODING_NUM_VERSION_BYTES + + // Extract data bytes and decode using data encoder + val dataBytes = new Array[Byte](dataLength) + Platform.copyMemory( + dataWithVersion, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES, + dataBytes, Platform.BYTE_ARRAY_OFFSET, + dataLength + ) + dataBytes + } else { + dataWithVersion + } - // Extract data bytes and decode using data encoder - val dataBytes = new Array[Byte](dataLength) - Platform.copyMemory( - dataWithVersion, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES, - dataBytes, Platform.BYTE_ARRAY_OFFSET, - dataLength - ) - dataEncoder.decodeKey(dataBytes) + dataEncoder.decodeKey(rowBytes) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index cd9fdb9469d6..32696fbac05a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -650,7 +650,7 @@ case class StateRowEncoderCacheKey( object RocksDBStateStoreProvider { // Version as a single byte that specifies the encoding of the row data in RocksDB val STATE_ENCODING_NUM_VERSION_BYTES = 1 - val STATE_ENCODING_VERSION: Byte = 0 + val STATE_ENCODING_VERSION: Byte = 1 val VIRTUAL_COL_FAMILY_PREFIX_BYTES = 2 val SCHEMA_ID_PREFIX_BYTES = 2 From 46b6b487a3354c036d975ada92611dbad97f3ed8 Mon Sep 17 00:00:00 2001 From: Eric Marnadi <132308037+ericm-db@users.noreply.github.com> Date: Tue, 18 Feb 2025 13:18:56 -0800 Subject: [PATCH 04/16] Delete docs/streaming/structured-streaming-transform-with-state.md --- ...ructured-streaming-transform-with-state.md | 195 ------------------ 1 file changed, 195 deletions(-) delete mode 100644 docs/streaming/structured-streaming-transform-with-state.md diff --git a/docs/streaming/structured-streaming-transform-with-state.md b/docs/streaming/structured-streaming-transform-with-state.md deleted file mode 100644 index 0258494afdbf..000000000000 --- a/docs/streaming/structured-streaming-transform-with-state.md +++ /dev/null @@ -1,195 +0,0 @@ ---- -layout: global -displayTitle: TransformWithState Programming Guide -title: TransformWithState Programming Guide -license: | - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with - this work for additional information regarding copyright ownership. - The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with - the License. You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
---- - -# TransformWithState Programming Guide - -The TransformWithState API enables stateful stream processing in Structured Streaming, allowing you to maintain and update state for each unique key in your streaming data. - -## Overview - -TransformWithState provides functionality to: -* Maintain state variables for each unique grouping key -* Process records that share the same key together -* Schedule timers for future processing -* Control state expiration via TTL (Time To Live) -* Evolve schemas safely over time -* Initialize state from existing sources - -## Key Concepts - -### State Variables - -State variables store data associated with each unique key. There are three types of state variables, each with specific operations: - -#### ValueState -Provides operations for single value state: -* `exists()`: Check if state exists -* `get()`: Get the state value if it exists -* `update(newState)`: Update the value -* `clear()`: Remove this state - -#### ListState -Provides operations for list state: -* `exists()`: Check if state exists -* `get()`: Get iterator over all values -* `put(newState)`: Set the entire list -* `appendValue(newState)`: Append single value -* `appendList(newState)`: Append array of values -* `clear()`: Remove this state - -#### MapState -Provides operations for map state: -* `exists()`: Check if state exists -* `getValue(key)`: Get value for map key -* `containsKey(key)`: Check if key exists -* `updateValue(key, value)`: Update/add key-value pair -* `iterator()`: Get iterator over all key-value pairs -* `keys()`: Get iterator over all keys -* `values()`: Get iterator over all values -* `removeKey(key)`: Remove specific key -* `clear()`: Remove all state - -### Configuration - -When initializing state variables in `init()`, you can configure: - - - - - - - - - - - - - - - -
OptionDescription
stateNameUnique identifier for the state variable
encoderSQL encoder for the state type
ttlConfigTTL configuration for state expiration. Use TTLConfig(duration) to set expiration time or TTLConfig.NONE to disable TTL
- -State variables must be: -1. Declared as instance variables -2. Initialized in `init()` method using `getHandle.getValueState()`, `getListState()`, or `getMapState()` -3. Checked for existence using `.exists()` before first read (no default values) - -### Timers - -Timers allow scheduling callbacks for future execution in either processing-time or event-time. - -#### Timer Operations -Available through `getHandle` in StatefulProcessor: -* `registerTimer(expiryTimestampMs)`: Schedule a timer for the current key -* `deleteTimer(expiryTimestampMs)`: Delete a specific timer -* `listTimers()`: Get iterator of all registered timers for current key - -#### Timer Values -The `TimerValues` interface provides: -* `getCurrentProcessingTimeInMs()`: Current processing time as epoch milliseconds - * Constant throughout a streaming query trigger - * Use for processing-time timers -* `getCurrentWatermarkInMs()`: Current event time watermark - * Only available when watermark is set - * Returns 0 in first micro-batch - * Use for event-time timers - -#### Timer Handling -When a timer expires, `handleExpiredTimer` is called with: -* `key`: The grouping key for this timer -* `timerValues`: Current timer values -* `expiredTimerInfo`: Information about the expired timer - * `getExpiryTimeInMs()`: Get the timer's expiry time - -Key timer characteristics: -* Uniquely identified by timestamp within each key -* Cannot be created/deleted in `init()` -* Will fire even without data for that key -* Fire at or after scheduled time, never before -* Use `timerValues` timestamps for fault tolerance -* Can be created/deleted while processing data or handling another timer - -### Initial State -When providing initial state via `StatefulProcessorWithInitialState`: -* Available in first micro-batch only -* Must match grouping key schema of input rows -* Processed via `handleInitialState()` method -* Useful for migrating state from existing queries - -### State Store Encoding - -The TransformWithState API supports two encoding formats for storing state information: UnsafeRow and Avro. This can be configured using the `spark.sql.streaming.stateStore.encodingFormat` configuration option. - -#### Encoding Configuration - -Set the encoding format using: - -```scala -spark.conf.set("spark.sql.streaming.stateStore.encodingFormat", "avro") // or "unsaferow" -``` - -The default value is "unsaferow". - -#### Encoding Options - -##### UnsafeRow Encoding -- Default encoding format -- Optimized for performance with minimal serialization overhead -- Best choice when schema evolution is not required -- More memory-efficient for simple data types - -##### Avro Encoding -- Added in Spark 4.0.0 -- Provides robust schema evolution capabilities -- Supports all schema evolution operations: - - Adding fields - - Removing fields - - Changing field orders - - Type promotion (e.g., Int to Long) -- Better suited for complex data types and nested structures -- Recommended when schema flexibility is important - -#### Best Practices - -1. Use UnsafeRow encoding when: - - Performance is the top priority - - Schema is stable and unlikely to change - - Working primarily with simple data types - - Memory efficiency is crucial - -2. Use Avro encoding when: - - Schema evolution is needed - - Working with complex or nested data structures - - Long-term maintainability is prioritized over raw performance - - Planning to modify state schemas over time - -3. 
Consider the tradeoffs: - - UnsafeRow provides better performance but limited schema evolution - - Avro offers more flexibility but with some performance overhead - -#### Schema Evolution with Avro - -When using Avro encoding, you can evolve your schemas following standard Avro compatibility rules: - -1. Adding Fields: New fields must have default values -2. Removing Fields: Existing data will retain removed fields -3. Type Promotion: Must follow valid Avro type promotion rules -4. Reordering: Fields can be reordered freely From 519f6c9221c7eb8e004a779f50ba0c7c6bac6304 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Tue, 18 Feb 2025 13:20:12 -0800 Subject: [PATCH 05/16] stuff --- .../streaming/state/RocksDBStateEncoder.scala | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index d406f1a9ab89..1de8195cee24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -892,7 +892,16 @@ class AvroStateEncoder( valueAvroType) // Defining Avro writer for this struct type writer.write(avroData, encoder) // Avro.GenericDataRecord -> byte array encoder.flush() - out.toByteArray + val bytesToEncode = out.toByteArray + // prepend version byte + val encodedBytes = new Array[Byte](bytesToEncode.length + STATE_ENCODING_NUM_VERSION_BYTES) + Platform.putByte(encodedBytes, Platform.BYTE_ARRAY_OFFSET, STATE_ENCODING_VERSION) + // Platform.BYTE_ARRAY_OFFSET is the recommended way to memcopy b/w byte arrays. See Platform. 
+ Platform.copyMemory( + bytesToEncode, Platform.BYTE_ARRAY_OFFSET, + encodedBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES, + bytesToEncode.length) + encodedBytes } /** @@ -907,7 +916,9 @@ class AvroStateEncoder( if (valueBytes != null) { val reader = new GenericDatumReader[Any](valueAvroType) val decoder = DecoderFactory.get().binaryDecoder( - valueBytes, 0, valueBytes.length, null) + valueBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_VERSION, + valueBytes.length - Platform.BYTE_ARRAY_OFFSET - STATE_ENCODING_VERSION, + null) // bytes -> Avro.GenericDataRecord val genericData = reader.read(null, decoder) // Avro.GenericDataRecord -> InternalRow @@ -940,7 +951,9 @@ class AvroStateEncoder( if (valueBytes != null) { val reader = new GenericDatumReader[Any](writerSchema, readerSchema) val decoder = DecoderFactory.get().binaryDecoder( - valueBytes, 0, valueBytes.length, null) + valueBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_VERSION, + valueBytes.length - Platform.BYTE_ARRAY_OFFSET - STATE_ENCODING_VERSION, + null) // bytes -> Avro.GenericDataRecord val genericData = reader.read(null, decoder) // Avro.GenericDataRecord -> InternalRow From 8e5e6edfd3bddb1d09f0ffe6a7f6ce30ec367d29 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Tue, 18 Feb 2025 13:32:45 -0800 Subject: [PATCH 06/16] stuff --- .../streaming/state/RocksDBStateEncoder.scala | 35 ++++++++++++------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index 1de8195cee24..eaf213f705d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -892,33 +892,43 @@ class AvroStateEncoder( valueAvroType) // Defining Avro writer for this struct type writer.write(avroData, encoder) // Avro.GenericDataRecord -> byte array encoder.flush() - val bytesToEncode = out.toByteArray - // prepend version byte + prependVersionByte(out.toByteArray) + } + + private def prependVersionByte(bytesToEncode: Array[Byte]): Array[Byte] = { val encodedBytes = new Array[Byte](bytesToEncode.length + STATE_ENCODING_NUM_VERSION_BYTES) Platform.putByte(encodedBytes, Platform.BYTE_ARRAY_OFFSET, STATE_ENCODING_VERSION) // Platform.BYTE_ARRAY_OFFSET is the recommended way to memcopy b/w byte arrays. See Platform. 
Platform.copyMemory( - bytesToEncode, Platform.BYTE_ARRAY_OFFSET, + bytesToEncode, 0, encodedBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES, bytesToEncode.length) encodedBytes } + private def removeVersionByte(bytes: Array[Byte]): Array[Byte] = { + val resultBytes = new Array[Byte](bytes.length - STATE_ENCODING_NUM_VERSION_BYTES) + Platform.copyMemory( + bytes, STATE_ENCODING_NUM_VERSION_BYTES + Platform.BYTE_ARRAY_OFFSET, + resultBytes, 0, resultBytes.length + ) + resultBytes + } + /** * This method takes a byte array written using Avro encoding, and * deserializes to an UnsafeRow using the Avro deserializer */ def decodeFromAvroToUnsafeRow( - valueBytes: Array[Byte], + b: Array[Byte], avroDeserializer: AvroDeserializer, valueAvroType: Schema, valueProj: UnsafeProjection): UnsafeRow = { - if (valueBytes != null) { + if (b != null) { + val valueBytes = removeVersionByte(b) val reader = new GenericDatumReader[Any](valueAvroType) val decoder = DecoderFactory.get().binaryDecoder( - valueBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_VERSION, - valueBytes.length - Platform.BYTE_ARRAY_OFFSET - STATE_ENCODING_VERSION, - null) + valueBytes, 0, valueBytes.length, null) // bytes -> Avro.GenericDataRecord val genericData = reader.read(null, decoder) // Avro.GenericDataRecord -> InternalRow @@ -943,17 +953,16 @@ class AvroStateEncoder( * @return The deserialized UnsafeRow, or null if input bytes are null */ def decodeFromAvroToUnsafeRow( - valueBytes: Array[Byte], + b: Array[Byte], avroDeserializer: AvroDeserializer, writerSchema: Schema, readerSchema: Schema, valueProj: UnsafeProjection): UnsafeRow = { - if (valueBytes != null) { + if (b != null) { + val valueBytes = removeVersionByte(b) val reader = new GenericDatumReader[Any](writerSchema, readerSchema) val decoder = DecoderFactory.get().binaryDecoder( - valueBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_VERSION, - valueBytes.length - Platform.BYTE_ARRAY_OFFSET - STATE_ENCODING_VERSION, - null) + valueBytes, 0, valueBytes.length, null) // bytes -> Avro.GenericDataRecord val genericData = reader.read(null, decoder) // Avro.GenericDataRecord -> InternalRow From 231cb77fd158dbc002e490555eaf6f457913c659 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 19 Feb 2025 09:06:43 -0800 Subject: [PATCH 07/16] renaming param --- .../streaming/state/RocksDBStateEncoder.scala | 31 ++++++++++--------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index eaf213f705d0..49eb35d1bcab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -898,7 +898,6 @@ class AvroStateEncoder( private def prependVersionByte(bytesToEncode: Array[Byte]): Array[Byte] = { val encodedBytes = new Array[Byte](bytesToEncode.length + STATE_ENCODING_NUM_VERSION_BYTES) Platform.putByte(encodedBytes, Platform.BYTE_ARRAY_OFFSET, STATE_ENCODING_VERSION) - // Platform.BYTE_ARRAY_OFFSET is the recommended way to memcopy b/w byte arrays. See Platform. 
Platform.copyMemory( bytesToEncode, 0, encodedBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES, @@ -920,12 +919,12 @@ class AvroStateEncoder( * deserializes to an UnsafeRow using the Avro deserializer */ def decodeFromAvroToUnsafeRow( - b: Array[Byte], - avroDeserializer: AvroDeserializer, - valueAvroType: Schema, - valueProj: UnsafeProjection): UnsafeRow = { - if (b != null) { - val valueBytes = removeVersionByte(b) + rowBytes: Array[Byte], + avroDeserializer: AvroDeserializer, + valueAvroType: Schema, + valueProj: UnsafeProjection): UnsafeRow = { + if (rowBytes != null) { + val valueBytes = removeVersionByte(rowBytes) val reader = new GenericDatumReader[Any](valueAvroType) val decoder = DecoderFactory.get().binaryDecoder( valueBytes, 0, valueBytes.length, null) @@ -945,7 +944,7 @@ class AvroStateEncoder( * This method takes a byte array written using Avro encoding, and * deserializes to an UnsafeRow using the Avro deserializer * - * @param valueBytes The raw bytes containing Avro-encoded data + * @param rowBytes The raw bytes containing Avro-encoded data * @param avroDeserializer Custom deserializer to convert Avro records to InternalRows * @param writerSchema The Avro schema used when writing the data * @param readerSchema The Avro schema to use for reading (may be different from writer schema) @@ -953,13 +952,13 @@ class AvroStateEncoder( * @return The deserialized UnsafeRow, or null if input bytes are null */ def decodeFromAvroToUnsafeRow( - b: Array[Byte], - avroDeserializer: AvroDeserializer, - writerSchema: Schema, - readerSchema: Schema, - valueProj: UnsafeProjection): UnsafeRow = { - if (b != null) { - val valueBytes = removeVersionByte(b) + rowBytes: Array[Byte], + avroDeserializer: AvroDeserializer, + writerSchema: Schema, + readerSchema: Schema, + valueProj: UnsafeProjection): UnsafeRow = { + if (rowBytes != null) { + val valueBytes = removeVersionByte(rowBytes) val reader = new GenericDatumReader[Any](writerSchema, readerSchema) val decoder = DecoderFactory.get().binaryDecoder( valueBytes, 0, valueBytes.length, null) @@ -1686,6 +1685,8 @@ class NoPrefixKeyStateEncoder( // We want to skip over this byte, as it is not necessary, and dealt with in // dataEncoder.decodeKey. // This is fixed for subsequent versions + // When Avro encoding was used, we were not writing any version byte, and the + // first byte was instead used for schema ID. 
val rowBytes = if (version == 0 && !dataEncoder.supportsSchemaEvolution) { // Skip version byte to get to actual data val dataLength = dataWithVersion.length - STATE_ENCODING_NUM_VERSION_BYTES From d9fda36b7f00776a44cf6b816c516e19f6cd1c84 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 19 Feb 2025 09:29:19 -0800 Subject: [PATCH 08/16] tests pass --- .../streaming/state/RocksDBStateStoreSuite.scala | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index 4d939db8796b..b61ef9505e01 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -74,9 +74,19 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid assert(iter.hasNext) val kv = iter.next() + // For Avro encoding, the format is + // |--SCHEMA ID (2 bytes)--|--VERSION ID (1 byte)--|--DATA--| + val offset = if (conf.stateStoreEncodingFormat == "avro") { + SCHEMA_ID_PREFIX_BYTES + } else { + 0 + } + // Verify the version encoded in first byte of the key and value byte arrays - assert(Platform.getByte(kv.key, Platform.BYTE_ARRAY_OFFSET) === STATE_ENCODING_VERSION) - assert(Platform.getByte(kv.value, Platform.BYTE_ARRAY_OFFSET) === STATE_ENCODING_VERSION) + assert(Platform.getByte(kv.key, Platform.BYTE_ARRAY_OFFSET + offset) === + STATE_ENCODING_VERSION) + assert(Platform.getByte(kv.value, Platform.BYTE_ARRAY_OFFSET + offset) === + STATE_ENCODING_VERSION) } } From 891b9c7d5d6f7a1bde534b3d7801214d1bdfe03b Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 19 Feb 2025 12:24:35 -0800 Subject: [PATCH 09/16] prepending before schema id --- .../streaming/state/RocksDBStateEncoder.scala | 71 ++++++------------- .../state/RocksDBStateStoreSuite.scala | 14 +--- 2 files changed, 24 insertions(+), 61 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index 49eb35d1bcab..5f73036f250d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -892,7 +892,7 @@ class AvroStateEncoder( valueAvroType) // Defining Avro writer for this struct type writer.write(avroData, encoder) // Avro.GenericDataRecord -> byte array encoder.flush() - prependVersionByte(out.toByteArray) + out.toByteArray } private def prependVersionByte(bytesToEncode: Array[Byte]): Array[Byte] = { @@ -919,12 +919,11 @@ class AvroStateEncoder( * deserializes to an UnsafeRow using the Avro deserializer */ def decodeFromAvroToUnsafeRow( - rowBytes: Array[Byte], + valueBytes: Array[Byte], avroDeserializer: AvroDeserializer, valueAvroType: Schema, valueProj: UnsafeProjection): UnsafeRow = { - if (rowBytes != null) { - val valueBytes = removeVersionByte(rowBytes) + if (valueBytes != null) { val reader = new GenericDatumReader[Any](valueAvroType) val decoder = DecoderFactory.get().binaryDecoder( valueBytes, 0, valueBytes.length, null) @@ -944,7 +943,7 @@ class AvroStateEncoder( * This method takes a byte array written using Avro encoding, and * 
deserializes to an UnsafeRow using the Avro deserializer * - * @param rowBytes The raw bytes containing Avro-encoded data + * @param valueBytes The raw bytes containing Avro-encoded data * @param avroDeserializer Custom deserializer to convert Avro records to InternalRows * @param writerSchema The Avro schema used when writing the data * @param readerSchema The Avro schema to use for reading (may be different from writer schema) @@ -952,13 +951,12 @@ class AvroStateEncoder( * @return The deserialized UnsafeRow, or null if input bytes are null */ def decodeFromAvroToUnsafeRow( - rowBytes: Array[Byte], + valueBytes: Array[Byte], avroDeserializer: AvroDeserializer, writerSchema: Schema, readerSchema: Schema, valueProj: UnsafeProjection): UnsafeRow = { - if (rowBytes != null) { - val valueBytes = removeVersionByte(rowBytes) + if (valueBytes != null) { val reader = new GenericDatumReader[Any](writerSchema, readerSchema) val decoder = DecoderFactory.get().binaryDecoder( valueBytes, 0, valueBytes.length, null) @@ -977,7 +975,7 @@ class AvroStateEncoder( private val out = new ByteArrayOutputStream override def encodeKey(row: UnsafeRow): Array[Byte] = { - keyStateEncoderSpec match { + val keyBytes = keyStateEncoderSpec match { case NoPrefixKeyStateEncoderSpec(_) => val avroRow = encodeUnsafeRowToAvro(row, avroEncoder.keySerializer, keyAvroType, out) @@ -988,6 +986,7 @@ class AvroStateEncoder( encodeUnsafeRowToAvro(row, avroEncoder.keySerializer, prefixKeyAvroType, out) case _ => throw unsupportedOperationForKeyStateEncoder("encodeKey") } + prependVersionByte(keyBytes) } override def encodeRemainingKey(row: UnsafeRow): Array[Byte] = { @@ -999,8 +998,8 @@ class AvroStateEncoder( case _ => throw unsupportedOperationForKeyStateEncoder("encodeRemainingKey") } // prepend stateSchemaId to the remaining key portion - encodeWithStateSchemaId( - StateSchemaIdRow(currentKeySchemaId, avroRow)) + prependVersionByte(encodeWithStateSchemaId( + StateSchemaIdRow(currentKeySchemaId, avroRow))) } /** @@ -1139,16 +1138,18 @@ class AvroStateEncoder( val encoder = EncoderFactory.get().binaryEncoder(out, null) writer.write(record, encoder) encoder.flush() - out.toByteArray + prependVersionByte(out.toByteArray) } override def encodeValue(row: UnsafeRow): Array[Byte] = { val avroRow = encodeUnsafeRowToAvro(row, avroEncoder.valueSerializer, valueAvroType, out) // prepend stateSchemaId to the Avro-encoded value portion - encodeWithStateSchemaId(StateSchemaIdRow(currentValSchemaId, avroRow)) + prependVersionByte( + encodeWithStateSchemaId(StateSchemaIdRow(currentValSchemaId, avroRow))) } - override def decodeKey(bytes: Array[Byte]): UnsafeRow = { + override def decodeKey(rowBytes: Array[Byte]): UnsafeRow = { + val bytes = removeVersionByte(rowBytes) keyStateEncoderSpec match { case NoPrefixKeyStateEncoderSpec(_) => val schemaIdRow = decodeStateSchemaIdRow(bytes) @@ -1162,7 +1163,8 @@ class AvroStateEncoder( } - override def decodeRemainingKey(bytes: Array[Byte]): UnsafeRow = { + override def decodeRemainingKey(rowBytes: Array[Byte]): UnsafeRow = { + val bytes = removeVersionByte(rowBytes) val schemaIdRow = decodeStateSchemaIdRow(bytes) keyStateEncoderSpec match { case PrefixKeyScanStateEncoderSpec(_, _) => @@ -1195,7 +1197,8 @@ class AvroStateEncoder( * @throws UnsupportedOperationException if a field's data type is not supported for range * scan decoding */ - override def decodePrefixKeyForRangeScan(bytes: Array[Byte]): UnsafeRow = { + override def decodePrefixKeyForRangeScan(rowBytes: Array[Byte]): UnsafeRow = { + val bytes 
= removeVersionByte(rowBytes) val reader = new GenericDatumReader[GenericRecord](rangeScanAvroType) val decoder = DecoderFactory.get().binaryDecoder(bytes, 0, bytes.length, null) val record = reader.read(null, decoder) @@ -1278,7 +1281,8 @@ class AvroStateEncoder( rowWriter.getRow() } - override def decodeValue(bytes: Array[Byte]): UnsafeRow = { + override def decodeValue(rowBytes: Array[Byte]): UnsafeRow = { + val bytes = removeVersionByte(rowBytes) val schemaIdRow = decodeStateSchemaIdRow(bytes) val writerSchema = getStateSchemaProvider.getSchemaMetadataValue( StateSchemaMetadataKey( @@ -1673,38 +1677,7 @@ class NoPrefixKeyStateEncoder( } override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = { - if (!useColumnFamilies) { - dataEncoder.decodeKey(keyBytes) - } else if (keyBytes == null) { - null - } else { - val dataWithVersion = keyBytes - - val version = Platform.getByte(dataWithVersion, Platform.BYTE_ARRAY_OFFSET) - // For version 0, we were writing an extra version byte in the key row. - // We want to skip over this byte, as it is not necessary, and dealt with in - // dataEncoder.decodeKey. - // This is fixed for subsequent versions - // When Avro encoding was used, we were not writing any version byte, and the - // first byte was instead used for schema ID. - val rowBytes = if (version == 0 && !dataEncoder.supportsSchemaEvolution) { - // Skip version byte to get to actual data - val dataLength = dataWithVersion.length - STATE_ENCODING_NUM_VERSION_BYTES - - // Extract data bytes and decode using data encoder - val dataBytes = new Array[Byte](dataLength) - Platform.copyMemory( - dataWithVersion, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES, - dataBytes, Platform.BYTE_ARRAY_OFFSET, - dataLength - ) - dataBytes - } else { - dataWithVersion - } - - dataEncoder.decodeKey(rowBytes) - } + dataEncoder.decodeKey(keyBytes) } override def supportPrefixKeyScan: Boolean = false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index b61ef9505e01..4d939db8796b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -74,19 +74,9 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid assert(iter.hasNext) val kv = iter.next() - // For Avro encoding, the format is - // |--SCHEMA ID (2 bytes)--|--VERSION ID (1 byte)--|--DATA--| - val offset = if (conf.stateStoreEncodingFormat == "avro") { - SCHEMA_ID_PREFIX_BYTES - } else { - 0 - } - // Verify the version encoded in first byte of the key and value byte arrays - assert(Platform.getByte(kv.key, Platform.BYTE_ARRAY_OFFSET + offset) === - STATE_ENCODING_VERSION) - assert(Platform.getByte(kv.value, Platform.BYTE_ARRAY_OFFSET + offset) === - STATE_ENCODING_VERSION) + assert(Platform.getByte(kv.key, Platform.BYTE_ARRAY_OFFSET) === STATE_ENCODING_VERSION) + assert(Platform.getByte(kv.value, Platform.BYTE_ARRAY_OFFSET) === STATE_ENCODING_VERSION) } } From f57d3d10db4a1f609003acfd14127c613b54036a Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 19 Feb 2025 12:25:22 -0800 Subject: [PATCH 10/16] resetting version --- .../execution/streaming/state/RocksDBStateStoreProvider.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 32696fbac05a..cd9fdb9469d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -650,7 +650,7 @@ case class StateRowEncoderCacheKey( object RocksDBStateStoreProvider { // Version as a single byte that specifies the encoding of the row data in RocksDB val STATE_ENCODING_NUM_VERSION_BYTES = 1 - val STATE_ENCODING_VERSION: Byte = 1 + val STATE_ENCODING_VERSION: Byte = 0 val VIRTUAL_COL_FAMILY_PREFIX_BYTES = 2 val SCHEMA_ID_PREFIX_BYTES = 2 From 6b64468690c9e64278eb494e1f0bdcbf4f40284c Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 19 Feb 2025 12:26:14 -0800 Subject: [PATCH 11/16] spacing --- .../streaming/state/RocksDBStateEncoder.scala | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index 5f73036f250d..46abc4ad81d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -919,10 +919,10 @@ class AvroStateEncoder( * deserializes to an UnsafeRow using the Avro deserializer */ def decodeFromAvroToUnsafeRow( - valueBytes: Array[Byte], - avroDeserializer: AvroDeserializer, - valueAvroType: Schema, - valueProj: UnsafeProjection): UnsafeRow = { + valueBytes: Array[Byte], + avroDeserializer: AvroDeserializer, + valueAvroType: Schema, + valueProj: UnsafeProjection): UnsafeRow = { if (valueBytes != null) { val reader = new GenericDatumReader[Any](valueAvroType) val decoder = DecoderFactory.get().binaryDecoder( @@ -951,11 +951,11 @@ class AvroStateEncoder( * @return The deserialized UnsafeRow, or null if input bytes are null */ def decodeFromAvroToUnsafeRow( - valueBytes: Array[Byte], - avroDeserializer: AvroDeserializer, - writerSchema: Schema, - readerSchema: Schema, - valueProj: UnsafeProjection): UnsafeRow = { + valueBytes: Array[Byte], + avroDeserializer: AvroDeserializer, + writerSchema: Schema, + readerSchema: Schema, + valueProj: UnsafeProjection): UnsafeRow = { if (valueBytes != null) { val reader = new GenericDatumReader[Any](writerSchema, readerSchema) val decoder = DecoderFactory.get().binaryDecoder( From 9ea0d749d408f7f466d12a71eddb833dcf831e77 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 19 Feb 2025 12:26:52 -0800 Subject: [PATCH 12/16] spacing --- .../execution/streaming/state/RocksDBStateEncoder.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index 46abc4ad81d3..9ea47622dd93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -919,10 +919,10 @@ class AvroStateEncoder( * deserializes to an UnsafeRow using the Avro 
deserializer */ def decodeFromAvroToUnsafeRow( - valueBytes: Array[Byte], - avroDeserializer: AvroDeserializer, - valueAvroType: Schema, - valueProj: UnsafeProjection): UnsafeRow = { + valueBytes: Array[Byte], + avroDeserializer: AvroDeserializer, + valueAvroType: Schema, + valueProj: UnsafeProjection): UnsafeRow = { if (valueBytes != null) { val reader = new GenericDatumReader[Any](valueAvroType) val decoder = DecoderFactory.get().binaryDecoder( From 92cb3c9086456c8f158ed80358014aae96d5b05a Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 19 Feb 2025 13:04:07 -0800 Subject: [PATCH 13/16] Platform.BYTE_ARRAY_OFFSET --- .../sql/execution/streaming/state/RocksDBStateEncoder.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index 9ea47622dd93..519874d722f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -899,7 +899,7 @@ class AvroStateEncoder( val encodedBytes = new Array[Byte](bytesToEncode.length + STATE_ENCODING_NUM_VERSION_BYTES) Platform.putByte(encodedBytes, Platform.BYTE_ARRAY_OFFSET, STATE_ENCODING_VERSION) Platform.copyMemory( - bytesToEncode, 0, + bytesToEncode, Platform.BYTE_ARRAY_OFFSET, encodedBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES, bytesToEncode.length) encodedBytes @@ -909,7 +909,7 @@ class AvroStateEncoder( val resultBytes = new Array[Byte](bytes.length - STATE_ENCODING_NUM_VERSION_BYTES) Platform.copyMemory( bytes, STATE_ENCODING_NUM_VERSION_BYTES + Platform.BYTE_ARRAY_OFFSET, - resultBytes, 0, resultBytes.length + resultBytes, Platform.BYTE_ARRAY_OFFSET, resultBytes.length ) resultBytes } From 066834728aec64a468ce9c7769843c872177dbf5 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 19 Feb 2025 13:13:07 -0800 Subject: [PATCH 14/16] adding test --- .../streaming/state/RocksDBStateStoreSuite.scala | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index 4d939db8796b..34af68b011ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -77,6 +77,16 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid // Verify the version encoded in first byte of the key and value byte arrays assert(Platform.getByte(kv.key, Platform.BYTE_ARRAY_OFFSET) === STATE_ENCODING_VERSION) assert(Platform.getByte(kv.value, Platform.BYTE_ARRAY_OFFSET) === STATE_ENCODING_VERSION) + + val (expectedKey, expectedValue) = if (conf.stateStoreEncodingFormat == "avro") { + (Array(0, 0, 0, 2, 2, 97, 2, 0), Array(0, 0, 0, 2, 2)) + } else { + (Array(0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 24, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 97, 0, 0, 0, 0, 0, 0, 0), + Array(0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0)) + } + assert(kv.key.sameElements(expectedKey)) + assert(kv.value.sameElements(expectedValue)) } } From 2a5a7c20c019fc0d6ff405750e852dead5220afa Mon Sep 17 00:00:00 2001 From: Eric 
Marnadi Date: Wed, 19 Feb 2025 19:36:16 -0800 Subject: [PATCH 15/16] comments --- .../streaming/state/RocksDBStateEncoder.scala | 15 +++++++++++++++ .../streaming/state/RocksDBStateStoreSuite.scala | 4 ++++ 2 files changed, 19 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index 519874d722f2..1267c2a138bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -895,6 +895,14 @@ class AvroStateEncoder( out.toByteArray } + /** + * Prepends a version byte to the beginning of a byte array. + * This is used to maintain backward compatibility and version control of + * the state encoding format. + * + * @param bytesToEncode The original byte array to prepend the version byte to + * @return A new byte array with the version byte prepended at the beginning + */ private def prependVersionByte(bytesToEncode: Array[Byte]): Array[Byte] = { val encodedBytes = new Array[Byte](bytesToEncode.length + STATE_ENCODING_NUM_VERSION_BYTES) Platform.putByte(encodedBytes, Platform.BYTE_ARRAY_OFFSET, STATE_ENCODING_VERSION) @@ -905,6 +913,13 @@ class AvroStateEncoder( encodedBytes } + /** + * Removes the version byte from the beginning of a byte array. + * This is used when decoding state data to get back to the original encoded format. + * + * @param bytes The byte array containing the version byte at the start + * @return A new byte array with the version byte removed + */ private def removeVersionByte(bytes: Array[Byte]): Array[Byte] = { val resultBytes = new Array[Byte](bytes.length - STATE_ENCODING_NUM_VERSION_BYTES) Platform.copyMemory( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index 34af68b011ed..5aea0077e2aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -78,6 +78,10 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid assert(Platform.getByte(kv.key, Platform.BYTE_ARRAY_OFFSET) === STATE_ENCODING_VERSION) assert(Platform.getByte(kv.value, Platform.BYTE_ARRAY_OFFSET) === STATE_ENCODING_VERSION) + // The test verifies that the actual key-value pair (kv) matches these expected byte patterns + // exactly using sameElements, which ensures the serialization format remains consistent and + // backward compatible. This is particularly important for state storage where the format + // needs to be stable across Spark versions. 
val (expectedKey, expectedValue) = if (conf.stateStoreEncodingFormat == "avro") { (Array(0, 0, 0, 2, 2, 97, 2, 0), Array(0, 0, 0, 2, 2)) } else { From ecb32531cb93a064cba28e07eb44afd576b51b5c Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Thu, 20 Feb 2025 08:52:42 -0800 Subject: [PATCH 16/16] fixed tests --- .../execution/streaming/state/RocksDBStateEncoder.scala | 4 ++-- .../sql/execution/streaming/state/RocksDBSuite.scala | 9 ++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index 1267c2a138bd..cf5f8ba5f2eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -903,7 +903,7 @@ class AvroStateEncoder( * @param bytesToEncode The original byte array to prepend the version byte to * @return A new byte array with the version byte prepended at the beginning */ - private def prependVersionByte(bytesToEncode: Array[Byte]): Array[Byte] = { + private[sql] def prependVersionByte(bytesToEncode: Array[Byte]): Array[Byte] = { val encodedBytes = new Array[Byte](bytesToEncode.length + STATE_ENCODING_NUM_VERSION_BYTES) Platform.putByte(encodedBytes, Platform.BYTE_ARRAY_OFFSET, STATE_ENCODING_VERSION) Platform.copyMemory( @@ -920,7 +920,7 @@ class AvroStateEncoder( * @param bytes The byte array containing the version byte at the start * @return A new byte array with the version byte removed */ - private def removeVersionByte(bytes: Array[Byte]): Array[Byte] = { + private[sql] def removeVersionByte(bytes: Array[Byte]): Array[Byte] = { val resultBytes = new Array[Byte](bytes.length - STATE_ENCODING_NUM_VERSION_BYTES) Platform.copyMemory( bytes, STATE_ENCODING_NUM_VERSION_BYTES + Platform.BYTE_ARRAY_OFFSET, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala index 7d4614d59973..50240c0605e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala @@ -433,7 +433,8 @@ class RocksDBStateEncoderSuite extends SparkFunSuite { val encodedRemainingKey = encoder.encodeRemainingKey(remainingKeyRow) // Verify schema ID in remaining key bytes - val decodedSchemaIdRow = encoder.decodeStateSchemaIdRow(encodedRemainingKey) + val decodedSchemaIdRow = encoder.decodeStateSchemaIdRow( + encoder.removeVersionByte(encodedRemainingKey)) assert(decodedSchemaIdRow.schemaId === 18, "Schema ID not preserved in prefix scan remaining key encoding") } @@ -462,7 +463,8 @@ class RocksDBStateEncoderSuite extends SparkFunSuite { val encodedRemainingKey = encoder.encodeRemainingKey(remainingKeyRow) // Verify schema ID in remaining key bytes - val decodedSchemaIdRow = encoder.decodeStateSchemaIdRow(encodedRemainingKey) + val decodedSchemaIdRow = encoder.decodeStateSchemaIdRow( + encoder.removeVersionByte(encodedRemainingKey)) assert(decodedSchemaIdRow.schemaId === 24, "Schema ID not preserved in range scan remaining key encoding") @@ -565,7 +567,8 @@ class RocksDBStateEncoderSuite extends SparkFunSuite { val encodedValue = valueEncoder.encodeValue(value) // Verify schema ID was included 
and preserved - val decodedSchemaIdRow = avroEncoder.decodeStateSchemaIdRow(encodedValue) + val decodedSchemaIdRow = avroEncoder.decodeStateSchemaIdRow( + avroEncoder.removeVersionByte(encodedValue)) assert(decodedSchemaIdRow.schemaId === 42, "Schema ID not preserved in single value encoding")
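Taken together, the encoder changes in this series mean every Avro-encoded state row now begins with a version byte, ahead of the two-byte schema-id prefix and the Avro payload. A rough sketch of that layout, inferred from the encoder code and the expected byte arrays in the updated RocksDBStateStoreSuite test (an illustration, not code from the patch):

```scala
// Assumed layout of an Avro-encoded state value after this series:
//   | STATE_ENCODING_VERSION (1 byte) | schema id (SCHEMA_ID_PREFIX_BYTES = 2 bytes) | Avro payload |
// Using the expected value bytes from the updated test:
val encodedValue: Array[Byte] = Array[Byte](0, 0, 0, 2, 2)

val versionByte   = encodedValue.head        // 0, matching STATE_ENCODING_VERSION
val schemaIdBytes = encodedValue.slice(1, 3) // two-byte schema id prefix
val avroPayload   = encodedValue.drop(3)     // remaining Avro-encoded data
```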