From 55590118f1492addfbb40a736bdf7d67123d303f Mon Sep 17 00:00:00 2001 From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com> Date: Thu, 13 Nov 2025 16:48:25 -0500 Subject: [PATCH 1/2] Initial implementation of union row converter --- arrow-row/src/lib.rs | 383 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 383 insertions(+) diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs index 5f690e9a6734..33317569f2c1 100644 --- a/arrow-row/src/lib.rs +++ b/arrow-row/src/lib.rs @@ -458,6 +458,9 @@ enum Codec { List(RowConverter), /// A row converter for the values array of a run-end encoded array RunEndEncoded(RowConverter), + /// Row converters for each union field (indexed by type_id) + /// and the encoding of null rows for each field + Union(Vec, Vec, UnionMode), } impl Codec { @@ -524,6 +527,35 @@ impl Codec { Ok(Self::Struct(converter, owned)) } + DataType::Union(fields, mode) => { + // similar to dictionaries and lists, we set descending to false and negate nulls_first + // since the encodedc ontents will be inverted if descending is set + let options = SortOptions { + descending: false, + nulls_first: sort_field.options.nulls_first != sort_field.options.descending, + }; + + let mut converters = Vec::with_capacity(fields.len()); + let mut null_rows = Vec::with_capacity(fields.len()); + + for (_type_id, field) in fields.iter() { + let sort_field = + SortField::new_with_options(field.data_type().clone(), options); + let converter = RowConverter::new(vec![sort_field])?; + + let null_array = new_null_array(field.data_type(), 1); + let nulls = converter.convert_columns(&[null_array])?; + let owned = OwnedRow { + data: nulls.buffer.into(), + config: nulls.config, + }; + + converters.push(converter); + null_rows.push(owned); + } + + Ok(Self::Union(converters, null_rows, *mode)) + } _ => Err(ArrowError::NotYetImplemented(format!( "not yet implemented: {:?}", sort_field.data_type @@ -592,6 +624,29 @@ impl Codec { let rows = 
converter.convert_columns(std::slice::from_ref(values))?; Ok(Encoder::RunEndEncoded(rows)) } + Codec::Union(converters, _, mode) => { + let union_array = array + .as_any() + .downcast_ref::() + .expect("expected Union array"); + + let type_ids = union_array.type_ids().clone(); + let offsets = union_array.offsets().cloned(); + + let mut child_rows = Vec::with_capacity(converters.len()); + for (type_id, converter) in converters.iter().enumerate() { + let child_array = union_array.child(type_id as i8); + let rows = converter.convert_columns(std::slice::from_ref(child_array))?; + child_rows.push(rows); + } + + Ok(Encoder::Union { + child_rows, + type_ids, + offsets, + mode: *mode, + }) + } } } @@ -602,6 +657,10 @@ impl Codec { Codec::Struct(converter, nulls) => converter.size() + nulls.data.len(), Codec::List(converter) => converter.size(), Codec::RunEndEncoded(converter) => converter.size(), + Codec::Union(converters, null_rows, _) => { + converters.iter().map(|c| c.size()).sum::() + + null_rows.iter().map(|n| n.data.len()).sum::() + } } } } @@ -622,6 +681,13 @@ enum Encoder<'a> { List(Rows), /// The row encoding of the values array RunEndEncoded(Rows), + /// The row encoding of each union field's child array, type_ids buffer, offsets buffer (for Dense), and mode + Union { + child_rows: Vec, + type_ids: ScalarBuffer, + offsets: Option>, + mode: UnionMode, + }, } /// Configure the data type and sort order for a given column @@ -681,6 +747,9 @@ impl RowConverter { } DataType::Struct(f) => f.iter().all(|x| Self::supports_datatype(x.data_type())), DataType::RunEndEncoded(_, values) => Self::supports_datatype(values.data_type()), + DataType::Union(fs, _mode) => fs + .iter() + .all(|(_, f)| Self::supports_datatype(f.data_type())), _ => false, } } @@ -1523,6 +1592,33 @@ fn row_lengths(cols: &[ArrayRef], encoders: &[Encoder]) -> LengthTracker { }, _ => unreachable!(), }, + Encoder::Union { + child_rows, + type_ids, + offsets, + mode, + } => { + let union_array = 
array.as_any().downcast_ref::().unwrap(); + + let lengths = (0..union_array.len()).map(|i| { + let type_id = type_ids[i]; + + let child_row_i = match (mode, offsets) { + (UnionMode::Dense, Some(offsets)) => offsets[i] as usize, + (UnionMode::Sparse, None) => i, + foreign => { + unreachable!("invalid union mode/offsets combination: {foreign:?}") + } + }; + + let child_row = child_rows[type_id as usize].row(child_row_i); + + // length: 1 byte null sentinel + 1 byte type_id + child row bytes + 1 + 1 + child_row.as_ref().len() + }); + + tracker.push_variable(lengths); + } } } @@ -1637,6 +1733,49 @@ fn encode_column( }, _ => unreachable!(), }, + Encoder::Union { + child_rows, + type_ids, + offsets: offsets_buf, + mode, + } => { + let _union_array = column.as_any().downcast_ref::().unwrap(); + let null_sentinel = if opts.descending { 0x00 } else { 0x01 }; + + offsets + .iter_mut() + .skip(1) + .enumerate() + .for_each(|(i, offset)| { + let type_id = type_ids[i]; + + let child_row_idx = match (mode, offsets_buf) { + (UnionMode::Dense, Some(o)) => o[i] as usize, + (UnionMode::Sparse, None) => i, + foreign => { + unreachable!("invalid union mode/offsets combination: {foreign:?}") + } + }; + + let child_row = child_rows[type_id as usize].row(child_row_idx); + let child_bytes = child_row.as_ref(); + + data[*offset] = null_sentinel; + + let type_id_byte = if opts.descending { + !(type_id as u8) + } else { + type_id as u8 + }; + data[*offset + 1] = type_id_byte; + + let child_start = *offset + 2; + let child_end = child_start + child_bytes.len(); + data[child_start..child_end].copy_from_slice(child_bytes); + + *offset = child_end; + }); + } } } @@ -1762,6 +1901,110 @@ unsafe fn decode_column( }, _ => unreachable!(), }, + Codec::Union(converters, null_rows, _mode) => { + let len = rows.len(); + + let DataType::Union(union_fields, mode) = &field.data_type else { + unreachable!() + }; + + let mut type_ids = Vec::with_capacity(len); + let mut rows_by_field: Vec> = vec![Vec::new(); 
converters.len()]; + + for (idx, row) in rows.iter_mut().enumerate() { + // skip the null sentinel + let mut cursor = 1; + + let type_id_byte = { + let id = row[cursor]; + cursor += 1; + + if options.descending { !id } else { id } + }; + + let type_id = type_id_byte as i8; + type_ids.push(type_id); + + let field_idx = type_id as usize; + + let child_row = &row[cursor..]; + rows_by_field[field_idx].push((idx, child_row)); + + *row = &row[row.len()..]; + } + + let mut child_arrays: Vec = Vec::with_capacity(converters.len()); + + let mut offsets = (*mode == UnionMode::Dense).then(|| Vec::with_capacity(len)); + + for (field_idx, converter) in converters.iter().enumerate() { + let field_rows = &rows_by_field[field_idx]; + + match &mode { + UnionMode::Dense => { + if field_rows.is_empty() { + let (_, field) = union_fields.iter().nth(field_idx).unwrap(); + child_arrays.push(arrow_array::new_empty_array(field.data_type())); + continue; + } + + let mut child_data = field_rows + .iter() + .map(|(_, bytes)| *bytes) + .collect::>(); + + let child_array = + unsafe { converter.convert_raw(&mut child_data, validate_utf8) }?; + + child_arrays.push(child_array.into_iter().next().unwrap()); + } + UnionMode::Sparse => { + let mut sparse_data: Vec<&[u8]> = Vec::with_capacity(len); + let mut field_row_iter = field_rows.iter().peekable(); + let null_row_bytes: &[u8] = &null_rows[field_idx].data; + + for idx in 0..len { + if let Some((next_idx, bytes)) = field_row_iter.peek() { + if *next_idx == idx { + sparse_data.push(*bytes); + + field_row_iter.next(); + continue; + } + } + sparse_data.push(null_row_bytes); + } + + let child_array = + unsafe { converter.convert_raw(&mut sparse_data, validate_utf8) }?; + child_arrays.push(child_array.into_iter().next().unwrap()); + } + } + } + + // build offsets for dense unions + if let Some(ref mut offsets_vec) = offsets { + let mut count = vec![0i32; converters.len()]; + for type_id in &type_ids { + let field_idx = *type_id as usize; + 
offsets_vec.push(count[field_idx]); + + count[field_idx] += 1; + } + } + + let type_ids_buffer = ScalarBuffer::from(type_ids); + let offsets_buffer = offsets.map(ScalarBuffer::from); + + let union_array = UnionArray::try_new( + union_fields.clone(), + type_ids_buffer, + offsets_buffer, + child_arrays, + )?; + + Arc::new(union_array) + } }; Ok(array) } @@ -3598,4 +3841,144 @@ mod tests { assert_eq!(unchecked_values_len, 13); assert!(checked_values_len > unchecked_values_len); } + + #[test] + fn test_sparse_union() { + // create a sparse union with Int32 (type_id = 0) and Utf8 (type_id = 1) + let int_array = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]); + let str_array = StringArray::from(vec![None, Some("b"), None, Some("d"), None]); + + // [1, "b", 3, "d", 5] + let type_ids = vec![0, 1, 0, 1, 0].into(); + + let union_fields = [ + (0, Arc::new(Field::new("int", DataType::Int32, false))), + (1, Arc::new(Field::new("str", DataType::Utf8, false))), + ] + .into_iter() + .collect(); + + let union_array = UnionArray::try_new( + union_fields, + type_ids, + None, + vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)], + ) + .unwrap(); + + let union_type = union_array.data_type().clone(); + let converter = RowConverter::new(vec![SortField::new(union_type)]).unwrap(); + + let rows = converter + .convert_columns(&[Arc::new(union_array.clone())]) + .unwrap(); + + // round trip + let back = converter.convert_rows(&rows).unwrap(); + let back_union = back[0].as_any().downcast_ref::().unwrap(); + + assert_eq!(union_array.len(), back_union.len()); + for i in 0..union_array.len() { + assert_eq!(union_array.type_id(i), back_union.type_id(i)); + } + } + + #[test] + fn test_dense_union() { + // create a dense union with Int32 (type_id = 0) and use Utf8 (type_id = 1) + let int_array = Int32Array::from(vec![1, 3, 5]); + let str_array = StringArray::from(vec!["a", "b"]); + + let type_ids = vec![0, 1, 0, 1, 0].into(); + + // [1, "a", 3, "b", 5] + let offsets = vec![0, 
0, 1, 1, 2].into(); + + let union_fields = [ + (0, Arc::new(Field::new("int", DataType::Int32, false))), + (1, Arc::new(Field::new("str", DataType::Utf8, false))), + ] + .into_iter() + .collect(); + + let union_array = UnionArray::try_new( + union_fields, + type_ids, + Some(offsets), // Dense mode + vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)], + ) + .unwrap(); + + let union_type = union_array.data_type().clone(); + let converter = RowConverter::new(vec![SortField::new(union_type)]).unwrap(); + + let rows = converter + .convert_columns(&[Arc::new(union_array.clone())]) + .unwrap(); + + // round trip + let back = converter.convert_rows(&rows).unwrap(); + let back_union = back[0].as_any().downcast_ref::().unwrap(); + + assert_eq!(union_array.len(), back_union.len()); + for i in 0..union_array.len() { + assert_eq!(union_array.type_id(i), back_union.type_id(i)); + } + } + + #[test] + fn test_union_ordering() { + let int_array = Int32Array::from(vec![100, 5, 20]); + let str_array = StringArray::from(vec!["z", "a"]); + + // [100, "z", 5, "a", 20] + let type_ids = vec![0, 1, 0, 1, 0].into(); + let offsets = vec![0, 0, 1, 1, 2].into(); + + let union_fields = [ + (0, Arc::new(Field::new("int", DataType::Int32, false))), + (1, Arc::new(Field::new("str", DataType::Utf8, false))), + ] + .into_iter() + .collect(); + + let union_array = UnionArray::try_new( + union_fields, + type_ids, + Some(offsets), + vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)], + ) + .unwrap(); + + let union_type = union_array.data_type().clone(); + let converter = RowConverter::new(vec![SortField::new(union_type)]).unwrap(); + + let rows = converter.convert_columns(&[Arc::new(union_array)]).unwrap(); + + /* + expected ordering + + row 2: 5 - type_id 0 + row 4: 20 - type_id 0 + row 0: 100 - type id 0 + row 3: "a" - type id 1 + row 1: "z" - type id 1 + */ + + // 5 < "z" + assert!(rows.row(2) < rows.row(1)); + + // 100 < "a" + assert!(rows.row(0) < rows.row(3)); + + // among ints + 
// 5 < 20 + assert!(rows.row(2) < rows.row(4)); + // 20 < 100 + assert!(rows.row(4) < rows.row(0)); + + // among strings + // "a" < "z" + assert!(rows.row(3) < rows.row(1)); + } } From e365f923f64f277c23a5122fa8f76396fc6029ea Mon Sep 17 00:00:00 2001 From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com> Date: Fri, 14 Nov 2025 10:15:17 -0500 Subject: [PATCH 2/2] Properly encode null sentinel --- arrow-row/src/lib.rs | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs index 33317569f2c1..c26e54e27861 100644 --- a/arrow-row/src/lib.rs +++ b/arrow-row/src/lib.rs @@ -1739,14 +1739,20 @@ fn encode_column( offsets: offsets_buf, mode, } => { - let _union_array = column.as_any().downcast_ref::().unwrap(); - let null_sentinel = if opts.descending { 0x00 } else { 0x01 }; + let union_array = as_union_array(column); + let null_sentinel = null_sentinel(opts); offsets .iter_mut() .skip(1) .enumerate() .for_each(|(i, offset)| { + let sentinel = if union_array.is_valid(i) { + 0x01 + } else { + null_sentinel + }; + let type_id = type_ids[i]; let child_row_idx = match (mode, offsets_buf) { @@ -1760,7 +1766,7 @@ fn encode_column( let child_row = child_rows[type_id as usize].row(child_row_idx); let child_bytes = child_row.as_ref(); - data[*offset] = null_sentinel; + data[*offset] = sentinel; let type_id_byte = if opts.descending { !(type_id as u8)