Skip to content

Commit adfcb1c

Browse files
committed
Refactored state machine approach to get 2x performance improvement.
1 parent 0878364 commit adfcb1c

File tree

1 file changed

+124
-187
lines changed

1 file changed

+124
-187
lines changed

arrow-avro/src/reader/mod.rs

Lines changed: 124 additions & 187 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@
8989
//! }
9090
//! ```
9191
//!
92-
9392
use crate::codec::{AvroField, AvroFieldBuilder};
9493
use crate::schema::{
9594
compare_schemas, generate_fingerprint, AvroSchema, Fingerprint, FingerprintAlgorithm, Schema,
@@ -140,15 +139,6 @@ fn is_incomplete_data(err: &ArrowError) -> bool {
140139
)
141140
}
142141

143-
#[derive(Debug)]
144-
enum DecoderState {
145-
Magic,
146-
Fingerprint,
147-
Record,
148-
SchemaChange,
149-
Finished,
150-
}
151-
152142
/// A low-level interface for decoding Avro-encoded bytes into Arrow `RecordBatch`.
153143
#[derive(Debug)]
154144
pub struct Decoder {
@@ -161,9 +151,7 @@ pub struct Decoder {
161151
utf8_view: bool,
162152
strict_mode: bool,
163153
pending_schema: Option<(Fingerprint, RecordDecoder)>,
164-
state: DecoderState,
165-
bytes_remaining: usize,
166-
fingerprint_buf: Vec<u8>,
154+
awaiting_body: bool,
167155
}
168156

169157
impl Decoder {
@@ -182,139 +170,135 @@ impl Decoder {
182170
/// - reach `batch_size` decoded rows.
183171
///
184172
/// Returns the number of bytes consumed.
185-
pub fn decode(&mut self, mut buf: &[u8]) -> Result<usize, ArrowError> {
186-
let max_read = buf.len();
187-
while !buf.is_empty() {
188-
match self.state {
189-
DecoderState::Magic => {
190-
self.fingerprint_buf.clear();
191-
let remaining =
192-
&SINGLE_OBJECT_MAGIC[SINGLE_OBJECT_MAGIC.len() - self.bytes_remaining..];
193-
let to_decode = buf.len().min(remaining.len());
194-
if !buf.starts_with(&remaining[..to_decode]) {
195-
return Err(ArrowError::ParseError(
196-
"Invalid avro single object encoding magic".to_string(),
197-
));
173+
pub fn decode(&mut self, data: &[u8]) -> Result<usize, ArrowError> {
174+
let mut total_consumed = 0usize;
175+
while total_consumed < data.len() && self.remaining_capacity > 0 {
176+
if !self.awaiting_body {
177+
if let Some(n) = self.handle_prefix(&data[total_consumed..])? {
178+
if n == 0 {
179+
break;
198180
}
199-
self.bytes_remaining -= to_decode;
200-
buf = &buf[to_decode..];
201-
if self.bytes_remaining == 0 {
202-
match self.fingerprint_algorithm {
203-
FingerprintAlgorithm::Rabin => {
204-
self.bytes_remaining = 8;
205-
}
206-
}
207-
self.state = DecoderState::Fingerprint;
181+
total_consumed += n;
182+
self.awaiting_body = true;
183+
if self.remaining_capacity == self.batch_size && self.pending_schema.is_some() {
184+
self.apply_pending_schema_if_batch_empty();
208185
}
209-
}
210-
DecoderState::Fingerprint => {
211-
let to_decode = self.bytes_remaining.min(buf.len());
212-
self.fingerprint_buf.extend_from_slice(&buf[..to_decode]);
213-
self.bytes_remaining -= to_decode;
214-
buf = &buf[to_decode..];
215-
if self.bytes_remaining == 0 {
216-
let fingerprint: Fingerprint = match self.fingerprint_algorithm {
217-
FingerprintAlgorithm::Rabin => Fingerprint::Rabin(u64::from_le_bytes(
218-
self.fingerprint_buf.as_slice().try_into().map_err(|e| {
219-
ArrowError::ParseError(format!(
220-
"Fingerprint buffer too small: {e}"
221-
))
222-
})?,
223-
)),
224-
};
225-
match self.active_fingerprint {
226-
Some(active_fp) if active_fp == fingerprint => {
227-
self.state = DecoderState::Record;
228-
}
229-
_ => match self.cache.shift_remove(&fingerprint) {
230-
Some(new_decoder) => {
231-
self.pending_schema = Some((fingerprint, new_decoder));
232-
self.state = if self.remaining_capacity < self.batch_size {
233-
self.remaining_capacity = 0;
234-
DecoderState::Finished
235-
} else {
236-
DecoderState::SchemaChange
237-
};
238-
}
239-
None => {
240-
return Err(ArrowError::ParseError(format!(
241-
"Unknown fingerprint: {fingerprint:?}"
242-
)));
243-
}
244-
},
245-
}
186+
if self.remaining_capacity == 0 {
187+
break;
246188
}
247189
}
248-
DecoderState::Record => match self.active_decoder.decode(buf, 1) {
249-
Ok(n) if n > 0 => {
250-
self.remaining_capacity -= 1;
251-
buf = &buf[n..];
252-
if self.remaining_capacity == 0 {
253-
self.state = DecoderState::Finished;
254-
} else {
255-
self.bytes_remaining = SINGLE_OBJECT_MAGIC.len();
256-
self.state = DecoderState::Magic;
257-
}
258-
}
259-
Ok(_) => {
260-
return Err(ArrowError::ParseError(
261-
"Record decoder consumed 0 bytes".into(),
262-
));
263-
}
264-
Err(e) => {
265-
if !is_incomplete_data(&e) {
266-
return Err(e);
267-
}
268-
return Ok(max_read - buf.len());
269-
}
270-
},
271-
DecoderState::SchemaChange => {
272-
if let Some((new_fingerprint, new_decoder)) = self.pending_schema.take() {
273-
if let Some(old_fingerprint) =
274-
self.active_fingerprint.replace(new_fingerprint)
275-
{
276-
let old_decoder =
277-
std::mem::replace(&mut self.active_decoder, new_decoder);
278-
self.cache.shift_remove(&old_fingerprint);
279-
self.cache.insert(old_fingerprint, old_decoder);
280-
} else {
281-
self.active_decoder = new_decoder;
282-
}
190+
}
191+
match self.active_decoder.decode(&data[total_consumed..], 1) {
192+
Ok(n) if n > 0 => {
193+
self.remaining_capacity -= 1;
194+
total_consumed += n;
195+
self.awaiting_body = false;
196+
}
197+
Ok(_) => {
198+
return Err(ArrowError::ParseError(
199+
"Record decoder consumed 0 bytes".into(),
200+
));
201+
}
202+
Err(e) => {
203+
return if is_incomplete_data(&e) {
204+
Ok(total_consumed)
205+
} else {
206+
Err(e)
283207
}
284-
self.state = DecoderState::Record;
285208
}
286-
DecoderState::Finished => return Ok(max_read - buf.len()),
287209
}
288210
}
289-
Ok(max_read)
211+
Ok(total_consumed)
212+
}
213+
214+
// Attempt to handle a single‑object‑encoding prefix at the current position.
215+
//
216+
// * Ok(None) – buffer does not start with the prefix.
217+
// * Ok(Some(0)) – prefix detected, but the buffer is too short; caller should await more bytes.
218+
// * Ok(Some(n)) – consumed `n > 0` bytes of a complete prefix (magic and fingerprint).
219+
fn handle_prefix(&mut self, buf: &[u8]) -> Result<Option<usize>, ArrowError> {
220+
// Need at least the magic bytes to decide (2 bytes).
221+
let Some(magic_bytes) = buf.get(..SINGLE_OBJECT_MAGIC.len()) else {
222+
return Ok(Some(0)); // Get more bytes
223+
};
224+
// Bail out early if the magic does not match.
225+
if magic_bytes != SINGLE_OBJECT_MAGIC {
226+
return Ok(None); // Continue to decode the next record
227+
}
228+
// Try to parse the fingerprint that follows the magic.
229+
let fingerprint_size = match self.fingerprint_algorithm {
230+
FingerprintAlgorithm::Rabin => self
231+
.handle_fingerprint(&buf[SINGLE_OBJECT_MAGIC.len()..], |bytes| {
232+
Fingerprint::Rabin(u64::from_le_bytes(bytes))
233+
})?,
234+
};
235+
// Convert the inner result into a “bytes consumed” count.
236+
// NOTE: Incomplete fingerprint consumes no bytes.
237+
let consumed = fingerprint_size.map_or(0, |n| n + SINGLE_OBJECT_MAGIC.len());
238+
Ok(Some(consumed))
239+
}
240+
241+
// Attempts to read and install a new fingerprint of `N` bytes.
242+
//
243+
// * Ok(None) – insufficient bytes (`buf.len() < `N`).
244+
// * Ok(Some(N)) – fingerprint consumed (always `N`).
245+
fn handle_fingerprint<const N: usize>(
246+
&mut self,
247+
buf: &[u8],
248+
fingerprint_from: impl FnOnce([u8; N]) -> Fingerprint,
249+
) -> Result<Option<usize>, ArrowError> {
250+
// Need enough bytes to get fingerprint (next N bytes)
251+
let Some(fingerprint_bytes) = buf.get(..N) else {
252+
return Ok(None); // Insufficient bytes
253+
};
254+
// SAFETY: length checked above.
255+
let new_fingerprint = fingerprint_from(fingerprint_bytes.try_into().unwrap());
256+
// If the fingerprint indicates a schema change, prepare to switch decoders.
257+
if self.active_fingerprint != Some(new_fingerprint) {
258+
let Some(new_decoder) = self.cache.shift_remove(&new_fingerprint) else {
259+
return Err(ArrowError::ParseError(format!(
260+
"Unknown fingerprint: {new_fingerprint:?}"
261+
)));
262+
};
263+
self.pending_schema = Some((new_fingerprint, new_decoder));
264+
// If there are already decoded rows, we must flush them first.
265+
// Reducing `remaining_capacity` to 0 ensures `flush` is called next.
266+
if self.remaining_capacity < self.batch_size {
267+
self.remaining_capacity = 0;
268+
}
269+
}
270+
Ok(Some(N))
271+
}
272+
273+
fn apply_pending_schema(&mut self) {
274+
if let Some((new_fingerprint, new_decoder)) = self.pending_schema.take() {
275+
if let Some(old_fingerprint) = self.active_fingerprint.replace(new_fingerprint) {
276+
let old_decoder = std::mem::replace(&mut self.active_decoder, new_decoder);
277+
self.cache.shift_remove(&old_fingerprint);
278+
self.cache.insert(old_fingerprint, old_decoder);
279+
} else {
280+
self.active_decoder = new_decoder;
281+
}
282+
}
283+
}
284+
285+
fn apply_pending_schema_if_batch_empty(&mut self) {
286+
if self.remaining_capacity != self.batch_size {
287+
return;
288+
}
289+
self.apply_pending_schema();
290290
}
291291

292292
/// Produce a `RecordBatch` if at least one row is fully decoded, returning
293293
/// `Ok(None)` if no new rows are available.
294294
pub fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
295-
match self.state {
296-
DecoderState::Finished => {
297-
let batch = self.active_decoder.flush()?;
298-
self.remaining_capacity = self.batch_size;
299-
if self.pending_schema.is_some() {
300-
self.state = DecoderState::SchemaChange;
301-
} else {
302-
self.bytes_remaining = SINGLE_OBJECT_MAGIC.len();
303-
self.state = DecoderState::Magic;
304-
}
305-
Ok(Some(batch))
306-
}
307-
DecoderState::Magic | DecoderState::Record => {
308-
if self.remaining_capacity < self.batch_size {
309-
let batch = self.active_decoder.flush()?;
310-
self.remaining_capacity = self.batch_size;
311-
Ok(Some(batch))
312-
} else {
313-
Ok(None)
314-
}
315-
}
316-
_ => Ok(None),
295+
if self.remaining_capacity == self.batch_size {
296+
return Ok(None);
317297
}
298+
let batch = self.active_decoder.flush()?;
299+
self.remaining_capacity = self.batch_size;
300+
self.apply_pending_schema();
301+
Ok(Some(batch))
318302
}
319303

320304
/// Returns the number of rows that can be added to this decoder before it is full.
@@ -423,9 +407,7 @@ impl ReaderBuilder {
423407
fingerprint_algorithm,
424408
strict_mode: self.strict_mode,
425409
pending_schema: None,
426-
state: DecoderState::Magic,
427-
bytes_remaining: SINGLE_OBJECT_MAGIC.len(),
428-
fingerprint_buf: Vec::new(),
410+
awaiting_body: false,
429411
}
430412
}
431413

@@ -839,48 +821,13 @@ mod test {
839821
);
840822
}
841823

842-
#[test]
843-
fn test_missing_initial_fingerprint_error() {
844-
let (store, _fp_int, _fp_long, schema_int, _schema_long) = make_two_schema_store();
845-
let mut decoder = ReaderBuilder::new()
846-
.with_batch_size(8)
847-
.with_reader_schema(schema_int.clone())
848-
.with_writer_schema_store(store)
849-
.build_decoder()
850-
.unwrap();
851-
let buf = [0x02u8, 0x00u8];
852-
let err = decoder.decode(&buf).expect_err("decode should error");
853-
let msg = err.to_string();
854-
assert!(
855-
msg.contains("Invalid avro single object encoding magic"),
856-
"unexpected message: {msg}"
857-
);
858-
}
859-
860-
#[test]
861-
fn test_handle_prefix_no_schema_store() {
862-
let (store, fp_int, _fp_long, schema_int, _schema_long) = make_two_schema_store();
863-
let mut decoder = make_decoder(&store, fp_int, &schema_int);
864-
let consumed = decoder
865-
.decode(&SINGLE_OBJECT_MAGIC[..])
866-
.expect("decode magic");
867-
assert_eq!(consumed, SINGLE_OBJECT_MAGIC.len());
868-
assert!(matches!(decoder.state, super::DecoderState::Fingerprint));
869-
match decoder.fingerprint_algorithm {
870-
FingerprintAlgorithm::Rabin => assert_eq!(decoder.bytes_remaining, 8),
871-
}
872-
assert!(decoder.pending_schema.is_none());
873-
}
874-
875824
#[test]
876825
fn test_handle_prefix_incomplete_magic() {
877826
let (store, fp_int, _fp_long, schema_int, _schema_long) = make_two_schema_store();
878827
let mut decoder = make_decoder(&store, fp_int, &schema_int);
879828
let buf = &SINGLE_OBJECT_MAGIC[..1];
880-
let consumed = decoder.decode(buf).unwrap();
881-
assert_eq!(consumed, buf.len());
882-
assert!(matches!(decoder.state, super::DecoderState::Magic));
883-
assert_eq!(decoder.bytes_remaining, SINGLE_OBJECT_MAGIC.len() - 1);
829+
let res = decoder.handle_prefix(buf).unwrap();
830+
assert_eq!(res, Some(0));
884831
assert!(decoder.pending_schema.is_none());
885832
}
886833

@@ -889,12 +836,8 @@ mod test {
889836
let (store, fp_int, _fp_long, schema_int, _schema_long) = make_two_schema_store();
890837
let mut decoder = make_decoder(&store, fp_int, &schema_int);
891838
let buf = [0xFFu8, 0x00u8, 0x01u8];
892-
let err = decoder.decode(&buf).expect_err("decode should error");
893-
let msg = err.to_string();
894-
assert!(
895-
msg.contains("Invalid avro single object encoding magic"),
896-
"unexpected message: {msg}"
897-
);
839+
let res = decoder.handle_prefix(&buf).unwrap();
840+
assert!(res.is_none());
898841
}
899842

900843
#[test]
@@ -906,13 +849,8 @@ mod test {
906849
};
907850
let mut buf = Vec::from(SINGLE_OBJECT_MAGIC);
908851
buf.extend_from_slice(&long_bytes[..4]);
909-
let consumed = decoder.decode(&buf).unwrap();
910-
assert_eq!(consumed, buf.len());
911-
assert!(matches!(decoder.state, super::DecoderState::Fingerprint));
912-
assert_eq!(decoder.fingerprint_buf.len(), 4);
913-
match decoder.fingerprint_algorithm {
914-
FingerprintAlgorithm::Rabin => assert_eq!(decoder.bytes_remaining, 4),
915-
}
852+
let res = decoder.handle_prefix(&buf).unwrap();
853+
assert_eq!(res, Some(0));
916854
assert!(decoder.pending_schema.is_none());
917855
}
918856

@@ -928,11 +866,10 @@ mod test {
928866
let mut buf = Vec::from(SINGLE_OBJECT_MAGIC);
929867
let Fingerprint::Rabin(v) = fp_long;
930868
buf.extend_from_slice(&v.to_le_bytes());
931-
let consumed = decoder.decode(&buf).unwrap();
869+
let consumed = decoder.handle_prefix(&buf).unwrap().unwrap();
932870
assert_eq!(consumed, buf.len());
933871
assert!(decoder.pending_schema.is_some());
934872
assert_eq!(decoder.pending_schema.as_ref().unwrap().0, fp_long);
935-
assert!(matches!(decoder.state, super::DecoderState::SchemaChange));
936873
}
937874

938875
#[test]

0 commit comments

Comments
 (0)