diff --git a/arrow-array/src/builder/map_builder.rs b/arrow-array/src/builder/map_builder.rs index 1d89d427aae1..012a454e76c9 100644 --- a/arrow-array/src/builder/map_builder.rs +++ b/arrow-array/src/builder/map_builder.rs @@ -61,6 +61,7 @@ pub struct MapBuilder { field_names: MapFieldNames, key_builder: K, value_builder: V, + key_field: Option, value_field: Option, } @@ -107,13 +108,27 @@ impl MapBuilder { field_names: field_names.unwrap_or_default(), key_builder, value_builder, + key_field: None, value_field: None, } } /// Override the field passed to [`MapBuilder::new`] /// - /// By default a nullable field is created with the name `values` + /// By default, a non-nullable field is created with the name `keys` + /// + /// Note: [`Self::finish`] and [`Self::finish_cloned`] will panic if the + /// field's data type does not match that of `K` or the field is nullable + pub fn with_keys_field(self, field: impl Into) -> Self { + Self { + key_field: Some(field.into()), + ..self + } + } + + /// Override the field passed to [`MapBuilder::new`] + /// + /// By default, a nullable field is created with the name `values` /// /// Note: [`Self::finish`] and [`Self::finish_cloned`] will panic if the /// field's data type does not match that of `V` @@ -194,11 +209,17 @@ impl MapBuilder { keys_arr.null_count() ); - let keys_field = Arc::new(Field::new( - self.field_names.key.as_str(), - keys_arr.data_type().clone(), - false, // always non-nullable - )); + let keys_field = match &self.key_field { + Some(f) => { + assert!(!f.is_nullable(), "Keys field must not be nullable"); + f.clone() + } + None => Arc::new(Field::new( + self.field_names.key.as_str(), + keys_arr.data_type().clone(), + false, // always non-nullable + )), + }; let values_field = match &self.value_field { Some(f) => f.clone(), None => Arc::new(Field::new( @@ -262,10 +283,10 @@ impl ArrayBuilder for MapBuilder { #[cfg(test)] mod tests { + use super::*; use crate::builder::{make_builder, Int32Builder, StringBuilder}; use crate::{Int32Array, StringArray}; - - use super::*; + use std::collections::HashMap; #[test] #[should_panic(expected = "Keys array must have no null values, found 1 null value(s)")] @@ -377,4 +398,67 @@ mod tests { ) ); } + + #[test] + fn test_with_keys_field() { + let mut key_metadata = HashMap::new(); + key_metadata.insert("foo".to_string(), "bar".to_string()); + let key_field = Arc::new( + Field::new("keys", DataType::Int32, false).with_metadata(key_metadata.clone()), + ); + let mut builder = MapBuilder::new(None, Int32Builder::new(), Int32Builder::new()) + .with_keys_field(key_field.clone()); + builder.keys().append_value(1); + builder.values().append_value(2); + builder.append(true).unwrap(); + let map = builder.finish(); + + assert_eq!(map.len(), 1); + assert_eq!( + map.data_type(), + &DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct( + vec![ + Arc::new( + Field::new("keys", DataType::Int32, false) + .with_metadata(key_metadata) + ), + Arc::new(Field::new("values", DataType::Int32, true)) + ] + .into() + ), + false, + )), + false + ) + ); + } + + #[test] + #[should_panic(expected = "Keys field must not be nullable")] + fn test_with_nullable_keys_field() { + let mut builder = MapBuilder::new(None, Int32Builder::new(), Int32Builder::new()) + .with_keys_field(Arc::new(Field::new("keys", DataType::Int32, true))); + + builder.keys().append_value(1); + builder.values().append_value(2); + builder.append(true).unwrap(); + + builder.finish(); + } + + #[test] + #[should_panic(expected = "Incorrect datatype")] + fn test_keys_field_type_mismatch() { + let mut builder = MapBuilder::new(None, Int32Builder::new(), Int32Builder::new()) + .with_keys_field(Arc::new(Field::new("keys", DataType::Utf8, false))); + + builder.keys().append_value(1); + builder.values().append_value(2); + builder.append(true).unwrap(); + + builder.finish(); + } } diff --git a/arrow-array/src/builder/struct_builder.rs b/arrow-array/src/builder/struct_builder.rs index 4a40c2201746..5cebc6485e0c 100644 --- a/arrow-array/src/builder/struct_builder.rs +++ b/arrow-array/src/builder/struct_builder.rs @@ -296,10 +296,11 @@ pub fn make_builder(datatype: &DataType, capacity: usize) -> Box panic!("The field of Map data type {t:?} should has a child Struct field"), + t => panic!("The field of Map data type {t:?} should have a child Struct field"), }, DataType::Struct(fields) => Box::new(StructBuilder::from_fields(fields.clone(), capacity)), t @ DataType::Dictionary(key_type, value_type) => {