diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs
index 46f09cd0aa2a..1bc9c170133b 100644
--- a/arrow-avro/src/reader/record.rs
+++ b/arrow-avro/src/reader/record.rs
@@ -310,7 +310,7 @@ impl Decoder {
                 Self::Record(arrow_fields.into(), encodings)
             }
             (Codec::Map(child), _) => {
-                let val_field = child.field_with_name("value").with_nullable(true);
+                let val_field = child.field_with_name("value");
                 let map_field = Arc::new(ArrowField::new(
                     "entries",
                     DataType::Struct(Fields::from(vec![
@@ -590,10 +590,23 @@ impl Decoder {
                        )));
                    }
                }
+                // Extract the value field nullability from the schema
+                let is_value_nullable = match map_field.data_type() {
+                    DataType::Struct(fields) => fields
+                        .iter()
+                        .find(|f| f.name() == "value")
+                        .map(|f| f.is_nullable())
+                        .unwrap_or(false),
+                    _ => true, // default to nullable
+                };
                 let entries_struct = StructArray::new(
                     Fields::from(vec![
                         Arc::new(ArrowField::new("key", DataType::Utf8, false)),
-                        Arc::new(ArrowField::new("value", val_arr.data_type().clone(), true)),
+                        Arc::new(ArrowField::new(
+                            "value",
+                            val_arr.data_type().clone(),
+                            is_value_nullable,
+                        )),
                     ]),
                     vec![Arc::new(key_arr), val_arr],
                     None,
@@ -740,6 +753,7 @@ fn sign_extend_to<const N: usize>(raw: &[u8]) -> Result<[u8; N], ArrowError> {
 mod tests {
     use super::*;
     use crate::codec::AvroField;
+    use crate::schema::Schema as AvroSchema;
     use arrow_array::{
         cast::AsArray, Array, Decimal128Array, DictionaryArray, FixedSizeBinaryArray,
         IntervalMonthDayNanoArray, ListArray, MapArray, StringArray, StructArray,
@@ -1471,4 +1485,111 @@ mod tests {
         assert!(int_array.is_null(0)); // row1 is null
         assert_eq!(int_array.value(1), 42); // row3 value is 42
     }
+
+    #[test]
+    fn test_map_with_non_nullable_value_type() {
+        let schema_json = r#"{
+            "type": "record",
+            "name": "MapRecord",
+            "fields": [
+                {"name": "map_field", "type": { "type": "map", "values": "string" }}
+            ]
+        }"#;
+
+        let schema: AvroSchema = serde_json::from_str(schema_json).unwrap();
+        let field = AvroField::try_from(&schema).unwrap();
+        let mut decoder = RecordDecoder::try_new_with_options(field.data_type(), true).unwrap();
+
+        let mut data = Vec::new();
+        data.extend_from_slice(&encode_avro_long(1)); // 1 entry in map
+        data.extend_from_slice(&encode_avro_bytes(b"key")); // key
+        data.extend_from_slice(&encode_avro_bytes(b"value")); // value
+        data.extend_from_slice(&encode_avro_long(0)); // end map
+
+        decoder.decode(&data, 1).unwrap();
+
+        let batch = decoder.flush().unwrap();
+        assert_eq!(batch.num_columns(), 1);
+        assert_eq!(batch.num_rows(), 1);
+
+        let map_arr = batch.column(0).as_any().downcast_ref::<MapArray>().unwrap();
+        assert_eq!(map_arr.len(), 1);
+        assert_eq!(map_arr.value_length(0), 1);
+
+        let entries = map_arr.value(0);
+        let key_arr = entries.column(0).as_string::<i32>();
+        let val_arr = entries.column(1).as_string::<i32>();
+        assert_eq!(key_arr.value(0), "key");
+        assert_eq!(val_arr.value(0), "value");
+    }
+
+    #[test]
+    fn test_map_with_nullable_value_type() {
+        let schema_json = r#"{
+            "type": "record",
+            "name": "MapRecord",
+            "fields": [
+                {"name": "map_field1", "type": { "type": "map", "values": ["null", "string"] }},
+                {"name": "map_field2", "type": { "type": "map", "values": ["string", "null"] }}
+            ]
+        }"#;
+
+        let schema: AvroSchema = serde_json::from_str(schema_json).unwrap();
+        let field = AvroField::try_from(&schema).unwrap();
+        let mut decoder = RecordDecoder::try_new_with_options(field.data_type(), true).unwrap();
+
+        let mut data = Vec::new();
+
+        // map_field1: ["null", "string"]
+        data.extend_from_slice(&encode_avro_long(2)); // 2 entries in map
+        // First entry: key1 -> null value (union branch 0)
+        data.extend_from_slice(&encode_avro_bytes(b"key1"));
+        data.extend_from_slice(&encode_avro_long(0)); // union branch 0 (null)
+        // Second entry: key2 -> "value2" (union branch 1)
+        data.extend_from_slice(&encode_avro_bytes(b"key2"));
+        data.extend_from_slice(&encode_avro_long(1)); // union branch 1 (string)
+        data.extend_from_slice(&encode_avro_bytes(b"value2"));
+        data.extend_from_slice(&encode_avro_long(0)); // end map
+
+        // map_field2: ["string", "null"]
+        data.extend_from_slice(&encode_avro_long(2)); // 2 entries in map
+        // First entry: key3 -> null value (union branch 1)
+        data.extend_from_slice(&encode_avro_bytes(b"key3"));
+        data.extend_from_slice(&encode_avro_long(1)); // union branch 1 (null)
+        // Second entry: key4 -> "value4" (union branch 0)
+        data.extend_from_slice(&encode_avro_bytes(b"key4"));
+        data.extend_from_slice(&encode_avro_long(0)); // union branch 0 (string)
+        data.extend_from_slice(&encode_avro_bytes(b"value4"));
+        data.extend_from_slice(&encode_avro_long(0)); // end map
+
+        decoder.decode(&data, 1).unwrap();
+
+        let batch = decoder.flush().unwrap();
+        assert_eq!(batch.num_columns(), 2);
+        assert_eq!(batch.num_rows(), 1);
+
+        // Check the first map field: ["null", "string"]
+        let map_arr1 = batch.column(0).as_any().downcast_ref::<MapArray>().unwrap();
+        assert_eq!(map_arr1.len(), 1);
+        assert_eq!(map_arr1.value_length(0), 2); // 2 entries
+        let entries1 = map_arr1.value(0);
+        let key_arr1 = entries1.column(0).as_string::<i32>();
+        let val_arr1 = entries1.column(1).as_string::<i32>();
+        assert_eq!(key_arr1.value(0), "key1");
+        assert!(val_arr1.is_null(0));
+        assert_eq!(key_arr1.value(1), "key2");
+        assert_eq!(val_arr1.value(1), "value2");
+
+        // Check the second map field: ["string", "null"]
+        let map_arr2 = batch.column(1).as_any().downcast_ref::<MapArray>().unwrap();
+        assert_eq!(map_arr2.len(), 1);
+        assert_eq!(map_arr2.value_length(0), 2); // 2 entries
+        let entries2 = map_arr2.value(0);
+        let key_arr2 = entries2.column(0).as_string::<i32>();
+        let val_arr2 = entries2.column(1).as_string::<i32>();
+        assert_eq!(key_arr2.value(0), "key3");
+        assert!(val_arr2.is_null(0));
+        assert_eq!(key_arr2.value(1), "key4");
+        assert_eq!(val_arr2.value(1), "value4");
+    }
 }
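
Note (not part of the patch): a minimal sketch of the Arrow type this change is expected to produce for an Avro map column, using the public arrow-schema API. `expected_map_type` is a hypothetical helper for illustration only; with the patch, the "value" child of the entries struct is non-nullable for a plain `"values": "string"` schema and nullable only when the values type is a union with "null".

    use std::sync::Arc;
    use arrow_schema::{DataType, Field, Fields};

    // Hypothetical helper (illustration only): the Arrow Map type the decoder
    // should report for an Avro map of string values, with the "value" child's
    // nullability taken from the Avro schema instead of being hard-coded to true.
    fn expected_map_type(value_nullable: bool) -> DataType {
        let entries = Field::new(
            "entries",
            DataType::Struct(Fields::from(vec![
                Arc::new(Field::new("key", DataType::Utf8, false)),
                Arc::new(Field::new("value", DataType::Utf8, value_nullable)),
            ])),
            false,
        );
        // The trailing `false` is the Map's keys_sorted flag.
        DataType::Map(Arc::new(entries), false)
    }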