diff --git a/crates/iceberg/src/arrow/record_batch_projector.rs b/crates/iceberg/src/arrow/record_batch_projector.rs index 878d0fe28e..2fdf0dd69c 100644 --- a/crates/iceberg/src/arrow/record_batch_projector.rs +++ b/crates/iceberg/src/arrow/record_batch_projector.rs @@ -15,139 +15,357 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashMap; use std::sync::Arc; -use arrow_array::{make_array, ArrayRef, RecordBatch, StructArray}; +use arrow_array::{ + make_array, Array, ArrayRef, BinaryArray, BooleanArray, FixedSizeListArray, Float32Array, + Float64Array, Int32Array, Int64Array, LargeListArray, ListArray, MapArray, RecordBatch, + RecordBatchOptions, StringArray, StructArray, +}; use arrow_buffer::NullBuffer; -use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef}; +use arrow_cast::cast; +use arrow_schema::{DataType, Field, Fields, Schema as ArrowSchema}; +use super::type_to_arrow_type; use crate::error::Result; +use crate::spec::{ + visit_struct_with_partner, Literal, MapType, NestedField, PartnerAccessor, PrimitiveLiteral, + PrimitiveType, Schema, SchemaWithPartnerVisitor, StructType, Type, +}; use crate::{Error, ErrorKind}; -/// Help to project specific field from `RecordBatch`` according to the fields id. -#[derive(Clone, Debug)] -pub(crate) struct RecordBatchProjector { - // A vector of vectors, where each inner vector represents the index path to access a specific field in a nested structure. - // E.g. [[0], [1, 2]] means the first field is accessed directly from the first column, - // while the second field is accessed from the second column and then from its third subcolumn (second column must be a struct column). - field_indices: Vec<Vec<usize>>, - // The schema reference after projection. This schema is derived from the original schema based on the given field IDs. - projected_schema: SchemaRef, +/// This accessor is used to search the field index path for each field in the schema. +/// +/// # Limitations of this accessor: +/// - The accessor will not search the key field of the map type or the +/// value field of the map type. It can only search the map field itself. +/// - The accessor will not search the element field of the list type. It +/// will only search the list field itself. +struct FieldIndexPathAccessor<F> { + field_id_fetch_func: F, + // For a primitive field, if the field is not found and this flag is true, the accessor + // returns the default primitive type for the field rather than raising an error. + // This is used for the default value handling in iceberg. + allow_default_primitive: bool, } -impl RecordBatchProjector { - /// Init ArrowFieldProjector - /// - /// This function will iterate through the field and fetch the field from the original schema according to the field ids. - /// The function to fetch the field id from the field is provided by `field_id_fetch_func`, return None if the field need to be skipped. - /// This function will iterate through the nested fields if the field is a struct, `searchable_field_func` can be used to control whether - /// iterate into the nested fields.
- pub(crate) fn new<F1, F2>( - original_schema: SchemaRef, - field_ids: &[i32], - field_id_fetch_func: F1, - searchable_field_func: F2, - ) -> Result<Self> - where - F1: Fn(&Field) -> Result<Option<i64>>, - F2: Fn(&Field) -> bool, - { - let mut field_indices = Vec::with_capacity(field_ids.len()); - let mut fields = Vec::with_capacity(field_ids.len()); - for &id in field_ids { - let mut field_index = vec![]; - let field = Self::fetch_field_index( - original_schema.fields(), - &mut field_index, - id as i64, - &field_id_fetch_func, - &searchable_field_func, - )? - .ok_or_else(|| { - Error::new(ErrorKind::Unexpected, "Field not found") - .with_context("field_id", id.to_string()) - })?; - fields.push(field.clone()); - field_indices.push(field_index); +impl<F> FieldIndexPathAccessor<F> +where F: Fn(&Field) -> Result<i32> +{ + pub fn new(field_id_fetch_func: F, ignore_not_found: bool) -> Self { + Self { + field_id_fetch_func, + allow_default_primitive: ignore_not_found, } - let delete_arrow_schema = Arc::new(Schema::new(fields)); - Ok(Self { - field_indices, - projected_schema: delete_arrow_schema, - }) } - fn fetch_field_index<F1, F2>( + fn fetch_field_index_path( + &self, fields: &Fields, index_vec: &mut Vec<usize>, - target_field_id: i64, - field_id_fetch_func: &F1, - searchable_field_func: &F2, - ) -> Result<Option<FieldRef>> - where - F1: Fn(&Field) -> Result<Option<i64>>, - F2: Fn(&Field) -> bool, - { + target_field_id: i32, + ) -> Result<Option<DataType>> { for (pos, field) in fields.iter().enumerate() { - let id = field_id_fetch_func(field)?; - if let Some(id) = id { - if target_field_id == id { - index_vec.push(pos); - return Ok(Some(field.clone())); - } + let id = (self.field_id_fetch_func)(field)?; + if target_field_id == id { + index_vec.push(pos); + return Ok(Some(field.data_type().clone())); } if let DataType::Struct(inner) = field.data_type() { - if searchable_field_func(field) { - if let Some(res) = Self::fetch_field_index( - inner, - index_vec, - target_field_id, - field_id_fetch_func, - searchable_field_func, - )? { - index_vec.push(pos); - return Ok(Some(res)); - } + if let Some(res) = self.fetch_field_index_path(inner, index_vec, target_field_id)? { + index_vec.push(pos); + return Ok(Some(res)); } } } Ok(None) } +} + +#[derive(Clone, Debug)] +enum FieldPath { + IndexPath(Vec<usize>, DataType), + Default(DataType), +} + +impl FieldPath { + fn data_type(&self) -> &DataType { + match self { + FieldPath::IndexPath(_, data_type) => data_type, + FieldPath::Default(data_type) => data_type, + } + } +} + +impl<F: Fn(&Field) -> Result<i32>> PartnerAccessor<FieldPath> for FieldIndexPathAccessor<F> { + fn struct_parner(&self, schema_partner: &FieldPath) -> Result<FieldPath> { + if !matches!(schema_partner.data_type(), DataType::Struct(_)) { + return Err(Error::new(ErrorKind::Unexpected, "Field is not a struct")); + } + Ok(schema_partner.clone()) + } + + fn field_partner( + &self, + struct_partner: &FieldPath, + field: &crate::spec::NestedField, + ) -> Result<FieldPath> { + let DataType::Struct(struct_fields) = &struct_partner.data_type() else { + return Err(Error::new(ErrorKind::Unexpected, "Field is not a struct")); + }; + let mut index_path = vec![]; + let Some(field) = self.fetch_field_index_path(struct_fields, &mut index_path, field.id)?
+ else { + if self.allow_default_primitive && field.field_type.is_primitive() { + return Ok(FieldPath::Default(type_to_arrow_type(&field.field_type)?)); + } else { + return Err(Error::new(ErrorKind::Unexpected, "Field not found") + .with_context("target_id", field.id.to_string()) + .with_context("struct fields", format!("{:?}", struct_fields))); + } + }; + Ok(FieldPath::IndexPath(index_path, field)) + } + + fn list_element_partner(&self, list_partner: &FieldPath) -> Result<FieldPath> { + match &list_partner.data_type() { + DataType::List(field) => Ok(FieldPath::Default(field.data_type().clone())), + DataType::LargeList(field) => Ok(FieldPath::Default(field.data_type().clone())), + DataType::FixedSizeList(field, _) => Ok(FieldPath::Default(field.data_type().clone())), + _ => Err(Error::new(ErrorKind::Unexpected, "Field is not a list")), + } + } + + fn map_key_partner(&self, map_partner: &FieldPath) -> Result<FieldPath> { + let DataType::Map(field, _) = map_partner.data_type() else { + return Err(Error::new(ErrorKind::Unexpected, "Field is not a map")); + }; + let DataType::Struct(inner_struct_fields) = field.data_type() else { + return Err(Error::new( + ErrorKind::Unexpected, + "inner field of map is not a struct", + )); + }; + Ok(FieldPath::Default( + inner_struct_fields[0].data_type().clone(), + )) + } + + fn map_value_partner(&self, map_partner: &FieldPath) -> Result<FieldPath> { + let DataType::Map(field, _) = &map_partner.data_type() else { + return Err(Error::new(ErrorKind::Unexpected, "Field is not a map")); + }; + let DataType::Struct(inner_struct_fields) = field.data_type() else { + return Err(Error::new( + ErrorKind::Unexpected, + "inner field of map is not a struct", + )); + }; + Ok(FieldPath::Default( + inner_struct_fields[1].data_type().clone(), + )) + } +} + +/// This visitor combines with `FieldIndexPathAccessor` to search the field index path for each field in the schema +/// and collect the paths into a hashmap.
+struct FieldIndexPathCollector { + field_index_path_map: HashMap<i32, Vec<usize>>, +} + +impl SchemaWithPartnerVisitor<FieldPath> for FieldIndexPathCollector { + type T = (); + + fn schema( + &mut self, + _schema: &crate::spec::Schema, + _partner: &FieldPath, + _value: Self::T, + ) -> Result<()> { + Ok(()) + } - /// Return the reference of projected schema - pub(crate) fn projected_schema_ref(&self) -> &SchemaRef { - &self.projected_schema + fn field( + &mut self, + field: &crate::spec::NestedFieldRef, + partner: &FieldPath, + _value: Self::T, + ) -> Result<()> { + match partner { + FieldPath::IndexPath(index_path, _) => { + self.field_index_path_map + .insert(field.id, index_path.clone()); + } + FieldPath::Default(_) => { + // Ignore the default field + } + } + Ok(()) } - /// Do projection with record batch - pub(crate) fn project_batch(&self, batch: RecordBatch) -> Result<RecordBatch> { - RecordBatch::try_new( - self.projected_schema.clone(), - self.project_column(batch.columns())?, - ) - .map_err(|err| Error::new(ErrorKind::DataInvalid, format!("{err}"))) + fn r#struct( + &mut self, + _struct: &StructType, + _partner: &FieldPath, + _results: Vec<Self::T>, + ) -> Result<()> { + Ok(()) } - /// Do projection with columns - pub(crate) fn project_column(&self, batch: &[ArrayRef]) -> Result<Vec<ArrayRef>> { - self.field_indices - .iter() - .map(|index_vec| Self::get_column_by_field_index(batch, index_vec)) - .collect::<Result<Vec<ArrayRef>>>() + fn list( + &mut self, + _list: &crate::spec::ListType, + _partner: &FieldPath, + _value: Self::T, + ) -> Result<()> { + Ok(()) } - fn get_column_by_field_index(batch: &[ArrayRef], field_index: &[usize]) -> Result<ArrayRef> { - let mut rev_iterator = field_index.iter().rev(); - let mut array = batch[*rev_iterator.next().unwrap()].clone(); + fn map( + &mut self, + _map: &crate::spec::MapType, + _partner: &FieldPath, + _key_value: Self::T, + _value: Self::T, + ) -> Result<()> { + Ok(()) + } + + fn primitive( + &mut self, + _p: &crate::spec::PrimitiveType, + _partner: &FieldPath, + ) -> Result<()> { + Ok(()) + } +} + +#[derive(Clone, Debug)] +pub(crate) struct DefaultValueGenerator; + +impl DefaultValueGenerator { + fn create_column( + &self, + target_type: &PrimitiveType, + prim_lit: &Option<PrimitiveLiteral>, + num_rows: usize, + ) -> Result<ArrayRef> { + Ok(match (target_type, prim_lit) { + (PrimitiveType::Boolean, Some(PrimitiveLiteral::Boolean(value))) => { + Arc::new(BooleanArray::from(vec![*value; num_rows])) + } + (PrimitiveType::Boolean, None) => { + let vals: Vec<Option<bool>> = vec![None; num_rows]; + Arc::new(BooleanArray::from(vals)) + } + (PrimitiveType::Int, Some(PrimitiveLiteral::Int(value))) => { + Arc::new(Int32Array::from(vec![*value; num_rows])) + } + (PrimitiveType::Int, None) => { + let vals: Vec<Option<i32>> = vec![None; num_rows]; + Arc::new(Int32Array::from(vals)) + } + (PrimitiveType::Long, Some(PrimitiveLiteral::Long(value))) => { + Arc::new(Int64Array::from(vec![*value; num_rows])) + } + (PrimitiveType::Long, None) => { + let vals: Vec<Option<i64>> = vec![None; num_rows]; + Arc::new(Int64Array::from(vals)) + } + (PrimitiveType::Float, Some(PrimitiveLiteral::Float(value))) => { + Arc::new(Float32Array::from(vec![value.0; num_rows])) + } + (PrimitiveType::Float, None) => { + let vals: Vec<Option<f32>> = vec![None; num_rows]; + Arc::new(Float32Array::from(vals)) + } + (PrimitiveType::Double, Some(PrimitiveLiteral::Double(value))) => { + Arc::new(Float64Array::from(vec![value.0; num_rows])) + } + (PrimitiveType::Double, None) => { + let vals: Vec<Option<f64>> = vec![None; num_rows]; + Arc::new(Float64Array::from(vals)) + } + (PrimitiveType::String, Some(PrimitiveLiteral::String(value))) => { +
Arc::new(StringArray::from(vec![value.clone(); num_rows])) + } + (PrimitiveType::String, None) => { + let vals: Vec<Option<String>> = vec![None; num_rows]; + Arc::new(StringArray::from(vals)) + } + (PrimitiveType::Binary, Some(PrimitiveLiteral::Binary(value))) => { + Arc::new(BinaryArray::from_vec(vec![value; num_rows])) + } + (PrimitiveType::Binary, None) => { + let vals: Vec<Option<&[u8]>> = vec![None; num_rows]; + Arc::new(BinaryArray::from_opt_vec(vals)) + } + (dt, _) => { + return Err(Error::new( + ErrorKind::Unexpected, + format!("unexpected target column type {}", dt), + )) + } + }) + } +} + +/// This accessor caches the index path of each field in the schema to +/// speed up later lookups of the same fields. +/// +/// # Limitations of this accessor: +/// - Because of the cache, this accessor must not be reused across multiple schemas; the caller must guarantee this, otherwise the behavior is undefined. +/// - The accessor will not search the key field of the map type or the value field of the map type. It can only search the map field itself. +/// - The accessor will not search the element field of the list type. It will only search the list field itself. +#[derive(Clone, Debug)] +struct CachedArrowArrayAccessor { + field_index_path_map: HashMap<i32, Vec<usize>>, + default_value_generator: Option<DefaultValueGenerator>, +} + +impl CachedArrowArrayAccessor { + pub fn new<F>( + iceberg_struct: &StructType, + arrow_fields: &Fields, + field_id_fetch: F, + default_value_generator: Option<DefaultValueGenerator>, + ) -> Result<Self> + where + F: Fn(&Field) -> Result<i32>, + { + let mut field_index_path_collector = FieldIndexPathCollector { + field_index_path_map: HashMap::new(), + }; + visit_struct_with_partner( + iceberg_struct, + &FieldPath::IndexPath(vec![], DataType::Struct(arrow_fields.clone())), + &mut field_index_path_collector, + &FieldIndexPathAccessor::new(field_id_fetch, default_value_generator.is_some()), + )?; + Ok(Self { + field_index_path_map: field_index_path_collector.field_index_path_map, + default_value_generator, + }) + } + + fn get_array_by_field_index_path( + arrays: &[ArrayRef], + field_index_path: &[usize], + ) -> Result<ArrayRef> { + let mut rev_iterator = field_index_path.iter().rev(); + let mut array = arrays[*rev_iterator.next().unwrap()].clone(); let mut null_buffer = array.logical_nulls(); for idx in rev_iterator { array = array .as_any() .downcast_ref::<StructArray>() - .ok_or(Error::new( - ErrorKind::Unexpected, - "Cannot convert Array to StructArray", - ))? + .ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + "Cannot convert Array to StructArray", + ) + })? .column(*idx) .clone(); null_buffer = NullBuffer::union(null_buffer.as_ref(), array.logical_nulls().as_ref()); @@ -158,61 +376,517 @@ impl RecordBatchProjector { } } +impl PartnerAccessor<ArrayRef> for CachedArrowArrayAccessor { + fn struct_parner(&self, schema_partner: &ArrayRef) -> Result<ArrayRef> { + if !matches!(schema_partner.data_type(), DataType::Struct(_)) { + return Err(Error::new( + ErrorKind::DataInvalid, + "The schema partner is not a struct type", + )); + } + Ok(schema_partner.clone()) + } + + fn field_partner(&self, struct_partner: &ArrayRef, field: &NestedField) -> Result<ArrayRef> { + let struct_array = struct_partner + .as_any() + .downcast_ref::<StructArray>() + .ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + "The struct partner is not a struct array", + ) + })?; + // Find the field index path for the field, then get the array by that index path.
+ if let Some(field_index_path) = self.field_index_path_map.get(&field.id) { + return Self::get_array_by_field_index_path(struct_array.columns(), field_index_path); + } + + // If the field is not found but it is a primitive field and the default value generator is set, + // use the default value generator to create the column. + let Some(default_value_generator) = &self.default_value_generator else { + return Err(Error::new( + ErrorKind::DataInvalid, + format!("Field {} not found", field.id), + )); + }; + let Some(target_type) = field.field_type.as_primitive_type() else { + return Err(Error::new( + ErrorKind::DataInvalid, + format!("Field {} not found", field.id), + )); + }; + let default_value = if let Some(default_value) = &field.initial_default { + let Literal::Primitive(primitive_literal) = default_value else { + return Err(Error::new( + ErrorKind::Unexpected, + format!( + "Default value for column must be primitive type, but encountered {:?}", + field.initial_default + ), + )); + }; + Some(primitive_literal.clone()) + } else { + None + }; + default_value_generator.create_column(target_type, &default_value, struct_array.len()) + } + + fn list_element_partner(&self, list_partner: &ArrayRef) -> Result<ArrayRef> { + match list_partner.data_type() { + DataType::List(_) => { + let list_array = list_partner + .as_any() + .downcast_ref::<ListArray>() + .ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + "The list partner is not a list array", + ) + })?; + Ok(list_array.values().clone()) + } + DataType::LargeList(_) => { + let list_array = list_partner + .as_any() + .downcast_ref::<LargeListArray>() + .ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + "The list partner is not a large list array", + ) + })?; + Ok(list_array.values().clone()) + } + DataType::FixedSizeList(_, _) => { + let list_array = list_partner + .as_any() + .downcast_ref::<FixedSizeListArray>() + .ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + "The list partner is not a fixed size list array", + ) + })?; + Ok(list_array.values().clone()) + } + _ => Err(Error::new( + ErrorKind::DataInvalid, + "The list partner is not a list type", + )), + } + } + + fn map_key_partner(&self, map_partner: &ArrayRef) -> Result<ArrayRef> { + let map_array = map_partner + .as_any() + .downcast_ref::<MapArray>() + .ok_or_else(|| { + Error::new(ErrorKind::DataInvalid, "The map partner is not a map array") + })?; + Ok(map_array.keys().clone()) + } + + fn map_value_partner(&self, map_partner: &ArrayRef) -> Result<ArrayRef> { + let map_array = map_partner + .as_any() + .downcast_ref::<MapArray>() + .ok_or_else(|| { + Error::new(ErrorKind::DataInvalid, "The map partner is not a map array") + })?; + Ok(map_array.values().clone()) + } +} + +#[derive(Clone, Debug)] +struct ArrowProjectVisitor; + +impl SchemaWithPartnerVisitor<ArrayRef> for ArrowProjectVisitor { + type T = ArrayRef; + + fn schema( + &mut self, + _schema: &crate::spec::Schema, + _partner: &ArrayRef, + value: ArrayRef, + ) -> Result<ArrayRef> { + Ok(value) + } + + fn field( + &mut self, + field: &crate::spec::NestedFieldRef, + partner: &ArrayRef, + value: ArrayRef, + ) -> Result<ArrayRef> { + let field_type = type_to_arrow_type(&field.field_type)?; + if !field_type.equals_datatype(value.data_type()) { + return Err( + Error::new(ErrorKind::Unexpected, "Field type does not match") + .with_context("iceberg_field", format!("{:?}", field)) + .with_context("converted_arrow_type", format!("{:?}", value.data_type())) + .with_context("original_arrow_type", format!("{:?}", partner.data_type())), + ); + } + Ok(value) + } + + fn r#struct( + &mut self, + r#struct: &StructType, + partner: &ArrayRef, +
results: Vec<ArrayRef>, + ) -> Result<ArrayRef> { + let DataType::Struct(new_arrow_struct_fields) = + type_to_arrow_type(&Type::Struct(r#struct.clone()))? + else { + return Err(Error::new(ErrorKind::Unexpected, "Field is not a struct")); + }; + + // # TODO: Refine this code. + // Struct array construction also crashes when `nulls` is `None` but the child arrays carry valid null buffers. + // This is a hack fix; maybe we should fix it upstream later. + let nulls = if results.is_empty() { + None + } else { + Some( + partner + .logical_nulls() + .unwrap_or(NullBuffer::new_valid(partner.len())), + ) + }; + let new_struct_array = StructArray::new(new_arrow_struct_fields, results, nulls); + + Ok(Arc::new(new_struct_array)) + } + + fn list( + &mut self, + list: &crate::spec::ListType, + partner: &ArrayRef, + value: ArrayRef, + ) -> Result<ArrayRef> { + let nulls = partner.nulls().cloned(); + let new_element_type = type_to_arrow_type(&list.element_field.field_type)?; + match partner.data_type() { + DataType::List(field) => { + let original_list = partner.as_any().downcast_ref::<ListArray>().unwrap(); + let field = Arc::new(field.as_ref().clone().with_data_type(new_element_type)); + let offsets = original_list.offsets().clone(); + let list_array = ListArray::new(field, offsets, value, nulls); + Ok(Arc::new(list_array)) + } + DataType::LargeList(field) => { + let original_list = partner.as_any().downcast_ref::<LargeListArray>().unwrap(); + let field = Arc::new(field.as_ref().clone().with_data_type(new_element_type)); + let offsets = original_list.offsets().clone(); + let list_array = LargeListArray::new(field, offsets, value, nulls); + Ok(Arc::new(list_array)) + } + DataType::FixedSizeList(field, size) => { + let field = Arc::new(field.as_ref().clone().with_data_type(new_element_type)); + let list_array = FixedSizeListArray::new(field, *size, value, nulls); + Ok(Arc::new(list_array)) + } + _ => Err(Error::new(ErrorKind::Unexpected, "Field is not a list")), + } + } + + fn map( + &mut self, + map: &MapType, + partner: &ArrayRef, + key_value: ArrayRef, + value: ArrayRef, + ) -> Result<ArrayRef> { + let original_array = partner.as_any().downcast_ref::<MapArray>().unwrap(); + let offsets = original_array.offsets().clone(); + let nulls = original_array.nulls().cloned(); + + let DataType::Map(field, ordered) = type_to_arrow_type(&Type::Map(map.clone()))? else { + return Err(Error::new(ErrorKind::Unexpected, "Field is not a map")); + }; + let DataType::Struct(inner_struct_fields) = field.data_type() else { + return Err(Error::new(ErrorKind::Unexpected, "Field is not a struct")); + }; + let entries = StructArray::new(inner_struct_fields.clone(), vec![key_value, value], None); + + Ok(Arc::new(MapArray::new( + field, offsets, entries, nulls, ordered, + ))) + } + + fn primitive(&mut self, ty: &PrimitiveType, partner: &ArrayRef) -> Result<ArrayRef> { + let target_type = type_to_arrow_type(&Type::Primitive(ty.clone()))?; + if target_type.equals_datatype(partner.data_type()) { + Ok(partner.clone()) + } else { + let res = cast(partner, &target_type)?; + Ok(res) + } + } +} + +/// Used to project a record batch to match the iceberg schema. +/// +/// This projector will handle the following schema evolution actions: +/// - Add new fields +/// - Type promotion +/// +/// The iceberg spec refers to other permissible schema evolution actions +/// (see https://iceberg.apache.org/spec/#schema-evolution): +/// renaming fields, deleting fields and reordering fields.
+/// Renames only affect the schema of the RecordBatch rather than the +/// columns themselves, so a single updated cached schema can +/// be re-used and no per-column actions are required. +/// Deletion and Reorder can be achieved without needing this +/// post-processing step by using the projection mask. +#[derive(Clone, Debug)] +pub(crate) struct RecordBatchProjector { + iceberg_struct_type: StructType, + accessor: CachedArrowArrayAccessor, + visitor: ArrowProjectVisitor, +} + +impl RecordBatchProjector { + pub(crate) fn new<F>( + expect_iceberg_schema: &Schema, + input_arrow_schema: &ArrowSchema, + field_id_fetch_func: F, + default_value_generator: Option<DefaultValueGenerator>, + ) -> Result<Self> + where + F: Fn(&Field) -> Result<i32>, + { + let accessor = CachedArrowArrayAccessor::new( + expect_iceberg_schema.as_struct(), + input_arrow_schema.fields(), + field_id_fetch_func, + default_value_generator, + )?; + Ok(Self { + iceberg_struct_type: expect_iceberg_schema.as_struct().clone(), + accessor, + visitor: ArrowProjectVisitor, + }) + } + + pub(crate) fn project_batch(&mut self, batch: RecordBatch) -> Result<RecordBatch> { + // Record the original row count of the batch; it is used for the select-empty case. + let row_num = batch.num_rows(); + + // Project the batch onto the iceberg schema, producing a struct array. + let array = visit_struct_with_partner( + &self.iceberg_struct_type, + &(Arc::new(StructArray::from(batch)) as ArrayRef), + &mut self.visitor, + &self.accessor, + )? + .as_any() + .downcast_ref::<StructArray>() + .unwrap() + .clone(); + + // Convert back to an arrow record batch. + let (fields, columns, nulls) = array.into_parts(); + if nulls.map(|n| n.null_count()).unwrap_or_default() != 0 { + return Err(Error::new( + ErrorKind::DataInvalid, + "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation", + )); + } + Ok(RecordBatch::try_new_with_options( + Arc::new(ArrowSchema::new(fields)), + columns, + &RecordBatchOptions::default() + .with_match_field_names(false) + .with_row_count(Some(row_num)), + )?)
+ } +} + #[cfg(test)] mod test { + use std::collections::HashMap; use std::sync::Arc; use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray, StructArray}; - use arrow_schema::{DataType, Field, Fields, Schema}; + use arrow_schema::{DataType, Field}; + use parquet::arrow::PARQUET_FIELD_ID_META_KEY; use crate::arrow::record_batch_projector::RecordBatchProjector; - use crate::{Error, ErrorKind}; + use crate::arrow::{get_field_id, schema_to_arrow_schema}; + use crate::spec::{ListType, MapType, NestedField, PrimitiveType, Schema, StructType, Type}; + + fn nested_schema_for_test() -> Schema { + // Int, Struct(Int,Int), String, List(Int), Struct(Struct(Int)), Map(String, List(Int)) + Schema::builder() + .with_schema_id(1) + .with_fields(vec![ + NestedField::required(0, "col0", Type::Primitive(PrimitiveType::Long)).into(), + NestedField::required( + 1, + "col1", + Type::Struct(StructType::new(vec![ + NestedField::required(5, "col_1_5", Type::Primitive(PrimitiveType::Long)) + .into(), + NestedField::required(6, "col_1_6", Type::Primitive(PrimitiveType::Long)) + .into(), + ])), + ) + .into(), + NestedField::required(2, "col2", Type::Primitive(PrimitiveType::String)).into(), + NestedField::required( + 3, + "col3", + Type::List(ListType::new( + NestedField::required(7, "element", Type::Primitive(PrimitiveType::Long)) + .into(), + )), + ) + .into(), + NestedField::required( + 4, + "col4", + Type::Struct(StructType::new(vec![NestedField::required( + 8, + "col_4_8", + Type::Struct(StructType::new(vec![NestedField::required( + 9, + "col_4_8_9", + Type::Primitive(PrimitiveType::Long), + ) + .into()])), + ) + .into()])), + ) + .into(), + NestedField::required( + 10, + "col5", + Type::Map(MapType::new( + NestedField::required(11, "key", Type::Primitive(PrimitiveType::String)) + .into(), + NestedField::required( + 12, + "value", + Type::List(ListType::new( + NestedField::required( + 13, + "item", + Type::Primitive(PrimitiveType::Long), + ) + .into(), + )), + ) + .into(), + )), + ) + .into(), + ]) + .build() + .unwrap() + } + + #[test] + fn test_fail_case_for_index_path_collect() { + let iceberg_schema = nested_schema_for_test(); + let arrow_schema = schema_to_arrow_schema(&iceberg_schema).unwrap(); + + // project map.key + let projected_schema = nested_schema_for_test().project(&[11]).unwrap(); + assert!( + RecordBatchProjector::new(&projected_schema, &arrow_schema, get_field_id, None) + .is_err() + ); + + // project map.value + let projected_schema = nested_schema_for_test().project(&[12]).unwrap(); + assert!( + RecordBatchProjector::new(&projected_schema, &arrow_schema, get_field_id, None) + .is_err() + ); + + // project list.element + let projected_schema = nested_schema_for_test().project(&[7]).unwrap(); + assert!( + RecordBatchProjector::new(&projected_schema, &arrow_schema, get_field_id, None) + .is_err() + ); + + // project map, list itself can success + let projected_schema = nested_schema_for_test().project(&[3, 10]).unwrap(); + RecordBatchProjector::new(&projected_schema, &arrow_schema, get_field_id, None).unwrap(); + } #[test] fn test_record_batch_projector_nested_level() { - let inner_fields = vec![ - Field::new("inner_field1", DataType::Int32, false), - Field::new("inner_field2", DataType::Utf8, false), - ]; - let fields = vec![ - Field::new("field1", DataType::Int32, false), - Field::new( - "field2", - DataType::Struct(Fields::from(inner_fields.clone())), - false, - ), - ]; - let schema = Arc::new(Schema::new(fields)); - - let field_id_fetch_func = |field: &Field| match 
field.name().as_str() { - "field1" => Ok(Some(1)), - "field2" => Ok(Some(2)), - "inner_field1" => Ok(Some(3)), - "inner_field2" => Ok(Some(4)), - _ => Err(Error::new(ErrorKind::Unexpected, "Field id not found")), - }; - let projector = - RecordBatchProjector::new(schema.clone(), &[1, 3], field_id_fetch_func, |_| true) + let iceberg_schema = Schema::builder() + .with_schema_id(1) + .with_fields(vec![ + Arc::new(NestedField::new( + 1, + "field1", + Type::Primitive(PrimitiveType::Int), + false, + )), + Arc::new(NestedField::new( + 2, + "field2", + Type::Struct(StructType::new(vec![ + Arc::new(NestedField::new( + 3, + "inner_field1", + Type::Primitive(PrimitiveType::Int), + false, + )), + Arc::new(NestedField::new( + 4, + "inner_field2", + Type::Primitive(PrimitiveType::String), + false, + )), + ])), + false, + )), + ]) + .build() + .unwrap(); + let arrow_schema = Arc::new(schema_to_arrow_schema(&iceberg_schema).unwrap()); + let projected_iceberg_schema = iceberg_schema.project(&[1, 3]).unwrap(); + let mut projector = + RecordBatchProjector::new(&projected_iceberg_schema, &arrow_schema, get_field_id, None) .unwrap(); - assert_eq!(projector.field_indices.len(), 2); - assert_eq!(projector.field_indices[0], vec![0]); - assert_eq!(projector.field_indices[1], vec![0, 1]); - let int_array = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; let inner_int_array = Arc::new(Int32Array::from(vec![4, 5, 6])) as ArrayRef; let inner_string_array = Arc::new(StringArray::from(vec!["x", "y", "z"])) as ArrayRef; let struct_array = Arc::new(StructArray::from(vec![ ( - Arc::new(inner_fields[0].clone()), + Arc::new( + Field::new("inner_field1", DataType::Int32, true).with_metadata( + HashMap::from_iter(vec![( + PARQUET_FIELD_ID_META_KEY.to_string(), + "3".to_string(), + )]), + ), + ), inner_int_array as ArrayRef, ), ( - Arc::new(inner_fields[1].clone()), + Arc::new( + Field::new("inner_field2", DataType::Utf8, true).with_metadata( + HashMap::from_iter(vec![( + PARQUET_FIELD_ID_META_KEY.to_string(), + "4".to_string(), + )]), + ), + ), inner_string_array as ArrayRef, ), ])) as ArrayRef; - let batch = RecordBatch::try_new(schema, vec![int_array, struct_array]).unwrap(); + let batch = RecordBatch::try_new(arrow_schema, vec![int_array, struct_array]).unwrap(); let projected_batch = projector.project_batch(batch).unwrap(); assert_eq!(projected_batch.num_columns(), 2); @@ -230,67 +904,4 @@ mod test { assert_eq!(projected_int_array.values(), &[1, 2, 3]); assert_eq!(projected_inner_int_array.values(), &[4, 5, 6]); } - - #[test] - fn test_field_not_found() { - let inner_fields = vec![ - Field::new("inner_field1", DataType::Int32, false), - Field::new("inner_field2", DataType::Utf8, false), - ]; - - let fields = vec![ - Field::new("field1", DataType::Int32, false), - Field::new( - "field2", - DataType::Struct(Fields::from(inner_fields.clone())), - false, - ), - ]; - let schema = Arc::new(Schema::new(fields)); - - let field_id_fetch_func = |field: &Field| match field.name().as_str() { - "field1" => Ok(Some(1)), - "field2" => Ok(Some(2)), - "inner_field1" => Ok(Some(3)), - "inner_field2" => Ok(Some(4)), - _ => Err(Error::new(ErrorKind::Unexpected, "Field id not found")), - }; - let projector = - RecordBatchProjector::new(schema.clone(), &[1, 5], field_id_fetch_func, |_| true); - - assert!(projector.is_err()); - } - - #[test] - fn test_field_not_reachable() { - let inner_fields = vec![ - Field::new("inner_field1", DataType::Int32, false), - Field::new("inner_field2", DataType::Utf8, false), - ]; - - let fields = vec![ - 
Field::new("field1", DataType::Int32, false), - Field::new( - "field2", - DataType::Struct(Fields::from(inner_fields.clone())), - false, - ), - ]; - let schema = Arc::new(Schema::new(fields)); - - let field_id_fetch_func = |field: &Field| match field.name().as_str() { - "field1" => Ok(Some(1)), - "field2" => Ok(Some(2)), - "inner_field1" => Ok(Some(3)), - "inner_field2" => Ok(Some(4)), - _ => Err(Error::new(ErrorKind::Unexpected, "Field id not found")), - }; - let projector = - RecordBatchProjector::new(schema.clone(), &[3], field_id_fetch_func, |_| false); - assert!(projector.is_err()); - - let projector = - RecordBatchProjector::new(schema.clone(), &[3], field_id_fetch_func, |_| true); - assert!(projector.is_ok()); - } } diff --git a/crates/iceberg/src/arrow/record_batch_transformer.rs b/crates/iceberg/src/arrow/record_batch_transformer.rs index 38543509bb..6828141f2c 100644 --- a/crates/iceberg/src/arrow/record_batch_transformer.rs +++ b/crates/iceberg/src/arrow/record_batch_transformer.rs @@ -15,63 +15,16 @@ // specific language governing permissions and limitations // under the License. -use std::collections::HashMap; use std::sync::Arc; -use arrow_array::{ - Array as ArrowArray, ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array, - Int32Array, Int64Array, NullArray, RecordBatch, RecordBatchOptions, StringArray, -}; -use arrow_cast::cast; -use arrow_schema::{ - DataType, FieldRef, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef, SchemaRef, -}; -use parquet::arrow::PARQUET_FIELD_ID_META_KEY; +use arrow_array::RecordBatch; +use arrow_schema::{Schema as ArrowSchema, SchemaRef as ArrowSchemaRef}; +use super::get_field_id; +use super::record_batch_projector::{DefaultValueGenerator, RecordBatchProjector}; use crate::arrow::schema_to_arrow_schema; -use crate::spec::{Literal, PrimitiveLiteral, Schema as IcebergSchema}; -use crate::{Error, ErrorKind, Result}; - -/// Indicates how a particular column in a processed RecordBatch should -/// be sourced. -#[derive(Debug)] -pub(crate) enum ColumnSource { - // signifies that a column should be passed through unmodified - // from the file's RecordBatch - PassThrough { - source_index: usize, - }, - - // signifies that a column from the file's RecordBatch has undergone - // type promotion so the source column with the given index needs - // to be promoted to the specified type - Promote { - target_type: DataType, - source_index: usize, - }, - - // Signifies that a new column has been inserted before the column - // with index `index`. (we choose "before" rather than "after" so - // that we can use usize; if we insert after, then we need to - // be able to store -1 here to signify that a new - // column is to be added at the front of the column list). - // If multiple columns need to be inserted at a given - // location, they should all be given the same index, as the index - // here refers to the original RecordBatch, not the interim state after - // a preceding operation. - Add { - target_type: DataType, - value: Option, - }, - // The iceberg spec refers to other permissible schema evolution actions - // (see https://iceberg.apache.org/spec/#schema-evolution): - // renaming fields, deleting fields and reordering fields. - // Renames only affect the schema of the RecordBatch rather than the - // columns themselves, so a single updated cached schema can - // be re-used and no per-column actions are required. - // Deletion and Reorder can be achieved without needing this - // post-processing step by using the projection mask. 
-} +use crate::spec::Schema as IcebergSchema; +use crate::Result; #[derive(Debug)] enum BatchTransform { @@ -81,14 +34,7 @@ PassThrough, Modify { - // Every transformed RecordBatch will have the same schema. We create the - // target just once and cache it here. Helpfully, Arc is needed in - // the constructor for RecordBatch, so we don't need an expensive copy - // each time we build a new RecordBatch - target_schema: Arc<ArrowSchema>, - - // Indicates how each column in the target schema is derived. - operations: Vec<ColumnSource>, + record_batch_projector: RecordBatchProjector, }, // Sometimes only the schema will need modifying, for example when @@ -137,21 +83,11 @@ impl RecordBatchTransformer { &mut self, record_batch: RecordBatch, ) -> Result<RecordBatch> { - Ok(match &self.batch_transform { + Ok(match &mut self.batch_transform { Some(BatchTransform::PassThrough) => record_batch, Some(BatchTransform::Modify { - target_schema, - operations, - }) => { - let options = RecordBatchOptions::default() - .with_match_field_names(false) - .with_row_count(Some(record_batch.num_rows())); - RecordBatch::try_new_with_options( - target_schema.clone(), - self.transform_columns(record_batch.columns(), operations)?, - &options, - )? - } + record_batch_projector, + }) => record_batch_projector.project_batch(record_batch)?, Some(BatchTransform::ModifySchema { target_schema }) => { record_batch.with_schema(target_schema.clone())? } @@ -179,36 +115,22 @@ snapshot_schema: &IcebergSchema, projected_iceberg_field_ids: &[i32], ) -> Result<BatchTransform> { - let mapped_unprojected_arrow_schema = Arc::new(schema_to_arrow_schema(snapshot_schema)?); - let field_id_to_mapped_schema_map = - Self::build_field_id_to_arrow_schema_map(&mapped_unprojected_arrow_schema)?; - - // Create a new arrow schema by selecting fields from mapped_unprojected, - // in the order of the field ids in projected_iceberg_field_ids - let fields: Result<Vec<FieldRef>> = projected_iceberg_field_ids - .iter() - .map(|field_id| { - Ok(field_id_to_mapped_schema_map - .get(field_id) - .ok_or(Error::new(ErrorKind::Unexpected, "field not found"))? - .0 - .clone()) - }) - .collect(); - - let target_schema = Arc::new(ArrowSchema::new(fields?)); + let projected_iceberg_schema = snapshot_schema.project(projected_iceberg_field_ids)?; + let target_schema = Arc::new(schema_to_arrow_schema(&projected_iceberg_schema)?); match Self::compare_schemas(source_schema, &target_schema) { SchemaComparison::Equivalent => Ok(BatchTransform::PassThrough), SchemaComparison::NameChangesOnly => Ok(BatchTransform::ModifySchema { target_schema }), SchemaComparison::Different => Ok(BatchTransform::Modify { - operations: Self::generate_transform_operations( - source_schema, - snapshot_schema, - projected_iceberg_field_ids, - field_id_to_mapped_schema_map, - )?, - target_schema, + record_batch_projector: { + let projected_schema = snapshot_schema.project(projected_iceberg_field_ids)?; + RecordBatchProjector::new( + &projected_schema, + source_schema, + get_field_id, + Some(DefaultValueGenerator), + )?
+ }, }), } } @@ -257,187 +179,6 @@ impl RecordBatchTransformer { SchemaComparison::Equivalent } } - - fn generate_transform_operations( - source_schema: &ArrowSchemaRef, - snapshot_schema: &IcebergSchema, - projected_iceberg_field_ids: &[i32], - field_id_to_mapped_schema_map: HashMap, - ) -> Result> { - let field_id_to_source_schema_map = - Self::build_field_id_to_arrow_schema_map(source_schema)?; - - projected_iceberg_field_ids.iter().map(|field_id|{ - let (target_field, _) = field_id_to_mapped_schema_map.get(field_id).ok_or( - Error::new(ErrorKind::Unexpected, "could not find field in schema") - )?; - let target_type = target_field.data_type(); - - Ok(if let Some((source_field, source_index)) = field_id_to_source_schema_map.get(field_id) { - // column present in source - - if source_field.data_type().equals_datatype(target_type) { - // no promotion required - ColumnSource::PassThrough { - source_index: *source_index - } - } else { - // promotion required - ColumnSource::Promote { - target_type: target_type.clone(), - source_index: *source_index, - } - } - } else { - // column must be added - let iceberg_field = snapshot_schema.field_by_id(*field_id).ok_or( - Error::new(ErrorKind::Unexpected, "Field not found in snapshot schema") - )?; - - let default_value = if let Some(iceberg_default_value) = - &iceberg_field.initial_default - { - let Literal::Primitive(primitive_literal) = iceberg_default_value else { - return Err(Error::new( - ErrorKind::Unexpected, - format!("Default value for column must be primitive type, but encountered {:?}", iceberg_default_value) - )); - }; - Some(primitive_literal.clone()) - } else { - None - }; - - ColumnSource::Add { - value: default_value, - target_type: target_type.clone(), - } - }) - }).collect() - } - - fn build_field_id_to_arrow_schema_map( - source_schema: &SchemaRef, - ) -> Result> { - let mut field_id_to_source_schema = HashMap::new(); - for (source_field_idx, source_field) in source_schema.fields.iter().enumerate() { - let this_field_id = source_field - .metadata() - .get(PARQUET_FIELD_ID_META_KEY) - .ok_or_else(|| { - Error::new( - ErrorKind::DataInvalid, - "field ID not present in parquet metadata", - ) - })? - .parse() - .map_err(|e| { - Error::new( - ErrorKind::DataInvalid, - format!("field id not parseable as an i32: {}", e), - ) - })?; - - field_id_to_source_schema - .insert(this_field_id, (source_field.clone(), source_field_idx)); - } - - Ok(field_id_to_source_schema) - } - - fn transform_columns( - &self, - columns: &[Arc], - operations: &[ColumnSource], - ) -> Result>> { - if columns.is_empty() { - return Ok(columns.to_vec()); - } - let num_rows = columns[0].len(); - - operations - .iter() - .map(|op| { - Ok(match op { - ColumnSource::PassThrough { source_index } => columns[*source_index].clone(), - - ColumnSource::Promote { - target_type, - source_index, - } => cast(&*columns[*source_index], target_type)?, - - ColumnSource::Add { target_type, value } => { - Self::create_column(target_type, value, num_rows)? 
- } - }) - }) - .collect() - } - - fn create_column( - target_type: &DataType, - prim_lit: &Option, - num_rows: usize, - ) -> Result { - Ok(match (target_type, prim_lit) { - (DataType::Boolean, Some(PrimitiveLiteral::Boolean(value))) => { - Arc::new(BooleanArray::from(vec![*value; num_rows])) - } - (DataType::Boolean, None) => { - let vals: Vec> = vec![None; num_rows]; - Arc::new(BooleanArray::from(vals)) - } - (DataType::Int32, Some(PrimitiveLiteral::Int(value))) => { - Arc::new(Int32Array::from(vec![*value; num_rows])) - } - (DataType::Int32, None) => { - let vals: Vec> = vec![None; num_rows]; - Arc::new(Int32Array::from(vals)) - } - (DataType::Int64, Some(PrimitiveLiteral::Long(value))) => { - Arc::new(Int64Array::from(vec![*value; num_rows])) - } - (DataType::Int64, None) => { - let vals: Vec> = vec![None; num_rows]; - Arc::new(Int64Array::from(vals)) - } - (DataType::Float32, Some(PrimitiveLiteral::Float(value))) => { - Arc::new(Float32Array::from(vec![value.0; num_rows])) - } - (DataType::Float32, None) => { - let vals: Vec> = vec![None; num_rows]; - Arc::new(Float32Array::from(vals)) - } - (DataType::Float64, Some(PrimitiveLiteral::Double(value))) => { - Arc::new(Float64Array::from(vec![value.0; num_rows])) - } - (DataType::Float64, None) => { - let vals: Vec> = vec![None; num_rows]; - Arc::new(Float64Array::from(vals)) - } - (DataType::Utf8, Some(PrimitiveLiteral::String(value))) => { - Arc::new(StringArray::from(vec![value.clone(); num_rows])) - } - (DataType::Utf8, None) => { - let vals: Vec> = vec![None; num_rows]; - Arc::new(StringArray::from(vals)) - } - (DataType::Binary, Some(PrimitiveLiteral::Binary(value))) => { - Arc::new(BinaryArray::from_vec(vec![value; num_rows])) - } - (DataType::Binary, None) => { - let vals: Vec> = vec![None; num_rows]; - Arc::new(BinaryArray::from_opt_vec(vals)) - } - (DataType::Null, _) => Arc::new(NullArray::new(num_rows)), - (dt, _) => { - return Err(Error::new( - ErrorKind::Unexpected, - format!("unexpected target column type {}", dt), - )) - } - }) - } } #[cfg(test)] @@ -454,24 +195,6 @@ mod test { use crate::arrow::record_batch_transformer::RecordBatchTransformer; use crate::spec::{Literal, NestedField, PrimitiveType, Schema, Type}; - #[test] - fn build_field_id_to_source_schema_map_works() { - let arrow_schema = arrow_schema_already_same_as_target(); - - let result = - RecordBatchTransformer::build_field_id_to_arrow_schema_map(&arrow_schema).unwrap(); - - let expected = HashMap::from_iter([ - (10, (arrow_schema.fields()[0].clone(), 0)), - (11, (arrow_schema.fields()[1].clone(), 1)), - (12, (arrow_schema.fields()[2].clone(), 2)), - (14, (arrow_schema.fields()[3].clone(), 3)), - (15, (arrow_schema.fields()[4].clone(), 4)), - ]); - - assert!(result.eq(&expected)); - } - #[test] fn processor_returns_properly_shaped_record_batch_when_no_schema_migration_required() { let snapshot_schema = Arc::new(iceberg_table_schema()); diff --git a/crates/iceberg/src/arrow/schema.rs b/crates/iceberg/src/arrow/schema.rs index c0cd1a2213..0e8d2c2099 100644 --- a/crates/iceberg/src/arrow/schema.rs +++ b/crates/iceberg/src/arrow/schema.rs @@ -226,7 +226,7 @@ pub fn arrow_type_to_type(ty: &DataType) -> Result { const ARROW_FIELD_DOC_KEY: &str = "doc"; -pub(super) fn get_field_id(field: &Field) -> Result { +pub(crate) fn get_field_id(field: &Field) -> Result { if let Some(value) = field.metadata().get(PARQUET_FIELD_ID_META_KEY) { return value.parse::().map_err(|e| { Error::new( diff --git a/crates/iceberg/src/arrow/value.rs b/crates/iceberg/src/arrow/value.rs index 
d78c4f4400..66142ad5ab 100644 --- a/crates/iceberg/src/arrow/value.rs +++ b/crates/iceberg/src/arrow/value.rs @@ -428,21 +428,17 @@ impl SchemaWithPartnerVisitor for ArrowArrayToIcebergStructConverter { struct ArrowArrayAccessor; impl PartnerAccessor for ArrowArrayAccessor { - fn struct_parner<'a>(&self, schema_partner: &'a ArrayRef) -> Result<&'a ArrayRef> { + fn struct_parner(&self, schema_partner: &ArrayRef) -> Result { if !matches!(schema_partner.data_type(), DataType::Struct(_)) { return Err(Error::new( ErrorKind::DataInvalid, "The schema partner is not a struct type", )); } - Ok(schema_partner) + Ok(schema_partner.clone()) } - fn field_partner<'a>( - &self, - struct_partner: &'a ArrayRef, - field: &NestedField, - ) -> Result<&'a ArrayRef> { + fn field_partner(&self, struct_partner: &ArrayRef, field: &NestedField) -> Result { let struct_array = struct_partner .as_any() .downcast_ref::() @@ -466,10 +462,10 @@ impl PartnerAccessor for ArrowArrayAccessor { format!("Field id {} not found in struct array", field.id), ) })?; - Ok(struct_array.column(field_pos)) + Ok(struct_array.column(field_pos).clone()) } - fn list_element_partner<'a>(&self, list_partner: &'a ArrayRef) -> Result<&'a ArrayRef> { + fn list_element_partner(&self, list_partner: &ArrayRef) -> Result { match list_partner.data_type() { DataType::List(_) => { let list_array = list_partner @@ -481,7 +477,7 @@ impl PartnerAccessor for ArrowArrayAccessor { "The list partner is not a list array", ) })?; - Ok(list_array.values()) + Ok(list_array.values().clone()) } DataType::LargeList(_) => { let list_array = list_partner @@ -493,7 +489,7 @@ impl PartnerAccessor for ArrowArrayAccessor { "The list partner is not a large list array", ) })?; - Ok(list_array.values()) + Ok(list_array.values().clone()) } DataType::FixedSizeList(_, _) => { let list_array = list_partner @@ -505,7 +501,7 @@ impl PartnerAccessor for ArrowArrayAccessor { "The list partner is not a fixed size list array", ) })?; - Ok(list_array.values()) + Ok(list_array.values().clone()) } _ => Err(Error::new( ErrorKind::DataInvalid, @@ -514,24 +510,24 @@ impl PartnerAccessor for ArrowArrayAccessor { } } - fn map_key_partner<'a>(&self, map_partner: &'a ArrayRef) -> Result<&'a ArrayRef> { + fn map_key_partner(&self, map_partner: &ArrayRef) -> Result { let map_array = map_partner .as_any() .downcast_ref::() .ok_or_else(|| { Error::new(ErrorKind::DataInvalid, "The map partner is not a map array") })?; - Ok(map_array.keys()) + Ok(map_array.keys().clone()) } - fn map_value_partner<'a>(&self, map_partner: &'a ArrayRef) -> Result<&'a ArrayRef> { + fn map_value_partner(&self, map_partner: &ArrayRef) -> Result { let map_array = map_partner .as_any() .downcast_ref::() .ok_or_else(|| { Error::new(ErrorKind::DataInvalid, "The map partner is not a map array") })?; - Ok(map_array.values()) + Ok(map_array.values().clone()) } } diff --git a/crates/iceberg/src/scan.rs b/crates/iceberg/src/scan.rs index 11c1d8190c..7397bdc225 100644 --- a/crates/iceberg/src/scan.rs +++ b/crates/iceberg/src/scan.rs @@ -266,19 +266,6 @@ impl<'a> TableScanBuilder<'a> { ) })?; - schema - .as_struct() - .field_by_id(field_id) - .ok_or_else(|| { - Error::new( - ErrorKind::FeatureUnsupported, - format!( - "Column {} is not a direct child of schema but a nested field, which is not supported now. 
Schema: {}", - column_name, schema - ), - ) - })?; - field_ids.push(field_id); } diff --git a/crates/iceberg/src/spec/schema/mod.rs b/crates/iceberg/src/spec/schema/mod.rs index b95244f42d..9d5fab803e 100644 --- a/crates/iceberg/src/spec/schema/mod.rs +++ b/crates/iceberg/src/spec/schema/mod.rs @@ -322,6 +322,36 @@ impl Schema { } } + /// Project the schema to a new schema with only the specified field ids. + pub fn project(&self, field_ids: &[i32]) -> Result { + let mut fields = vec![]; + let mut alias_to_id = BiHashMap::new(); + let mut identifier_field_ids = HashSet::new(); + + for field_id in field_ids { + let field = self.field_by_id(*field_id).ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!("Field id {} does not exist in schema", field_id), + ) + })?; + fields.push(field.clone()); + if let Some(alias) = self.id_to_name.get(field_id) { + alias_to_id.insert(alias.clone(), *field_id); + } + if self.identifier_field_ids.contains(field_id) { + identifier_field_ids.insert(*field_id); + } + } + + Schema::builder() + .with_schema_id(self.schema_id) + .with_fields(fields) + .with_alias(alias_to_id) + .with_identifier_field_ids(identifier_field_ids) + .build() + } + /// Get field by field id. pub fn field_by_id(&self, field_id: i32) -> Option<&NestedFieldRef> { self.id_to_field.get(&field_id) diff --git a/crates/iceberg/src/spec/schema/visitor.rs b/crates/iceberg/src/spec/schema/visitor.rs index ebb9b86bba..06db1be070 100644 --- a/crates/iceberg/src/spec/schema/visitor.rs +++ b/crates/iceberg/src/spec/schema/visitor.rs @@ -190,15 +190,15 @@ pub trait SchemaWithPartnerVisitor
<P> { /// Accessor used to get child partner from parent partner. pub trait PartnerAccessor<P> { /// Get the struct partner from schema partner. - fn struct_parner<'a>(&self, schema_partner: &'a P) -> Result<&'a P>; + fn struct_parner(&self, schema_partner: &P) -> Result<P>; /// Get the field partner from struct partner. - fn field_partner<'a>(&self, struct_partner: &'a P, field: &NestedField) -> Result<&'a P>; + fn field_partner(&self, struct_partner: &P, field: &NestedField) -> Result<P>; /// Get the list element partner from list partner. - fn list_element_partner<'a>(&self, list_partner: &'a P) -> Result<&'a P>; + fn list_element_partner(&self, list_partner: &P) -> Result<P>; /// Get the map key partner from map partner. - fn map_key_partner<'a>(&self, map_partner: &'a P) -> Result<&'a P>; + fn map_key_partner(&self, map_partner: &P) -> Result<P>; /// Get the map value partner from map partner. - fn map_value_partner<'a>(&self, map_partner: &'a P) -> Result<&'a P>; + fn map_value_partner(&self, map_partner: &P) -> Result<P>
; } /// Visiting a type in post order. @@ -212,32 +212,36 @@ pub(crate) fn visit_type_with_partner, A: Part Type::Primitive(p) => visitor.primitive(p, partner), Type::List(list) => { let list_element_partner = accessor.list_element_partner(partner)?; - visitor.before_list_element(&list.element_field, list_element_partner)?; + visitor.before_list_element(&list.element_field, &list_element_partner)?; let element_results = visit_type_with_partner( &list.element_field.field_type, - list_element_partner, + &list_element_partner, visitor, accessor, )?; - visitor.after_list_element(&list.element_field, list_element_partner)?; + visitor.after_list_element(&list.element_field, &list_element_partner)?; visitor.list(list, partner, element_results) } Type::Map(map) => { let key_partner = accessor.map_key_partner(partner)?; - visitor.before_map_key(&map.key_field, key_partner)?; - let key_result = - visit_type_with_partner(&map.key_field.field_type, key_partner, visitor, accessor)?; - visitor.after_map_key(&map.key_field, key_partner)?; + visitor.before_map_key(&map.key_field, &key_partner)?; + let key_result = visit_type_with_partner( + &map.key_field.field_type, + &key_partner, + visitor, + accessor, + )?; + visitor.after_map_key(&map.key_field, &key_partner)?; let value_partner = accessor.map_value_partner(partner)?; - visitor.before_map_value(&map.value_field, value_partner)?; + visitor.before_map_value(&map.value_field, &value_partner)?; let value_result = visit_type_with_partner( &map.value_field.field_type, - value_partner, + &value_partner, visitor, accessor, )?; - visitor.after_map_value(&map.value_field, value_partner)?; + visitor.after_map_value(&map.value_field, &value_partner)?; visitor.map(map, partner, key_result, value_result) } @@ -255,10 +259,10 @@ pub fn visit_struct_with_partner, A: PartnerAc let mut results = Vec::with_capacity(s.fields().len()); for field in s.fields() { let field_partner = accessor.field_partner(partner, field)?; - visitor.before_struct_field(field, field_partner)?; - let result = visit_type_with_partner(&field.field_type, field_partner, visitor, accessor)?; - visitor.after_struct_field(field, field_partner)?; - let result = visitor.field(field, field_partner, result)?; + visitor.before_struct_field(field, &field_partner)?; + let result = visit_type_with_partner(&field.field_type, &field_partner, visitor, accessor)?; + visitor.after_struct_field(field, &field_partner)?; + let result = visitor.field(field, &field_partner, result)?; results.push(result); } @@ -274,7 +278,7 @@ pub fn visit_schema_with_partner, A: PartnerAc ) -> Result { let result = visit_struct_with_partner( &schema.r#struct, - accessor.struct_parner(partner)?, + &accessor.struct_parner(partner)?, visitor, accessor, )?; diff --git a/crates/iceberg/src/writer/base_writer/equality_delete_writer.rs b/crates/iceberg/src/writer/base_writer/equality_delete_writer.rs index fb9682573b..4004e60bf0 100644 --- a/crates/iceberg/src/writer/base_writer/equality_delete_writer.rs +++ b/crates/iceberg/src/writer/base_writer/equality_delete_writer.rs @@ -20,12 +20,11 @@ use std::sync::Arc; use arrow_array::RecordBatch; -use arrow_schema::{DataType, Field, SchemaRef as ArrowSchemaRef}; +use arrow_schema::SchemaRef as ArrowSchemaRef; use itertools::Itertools; -use parquet::arrow::PARQUET_FIELD_ID_META_KEY; use crate::arrow::record_batch_projector::RecordBatchProjector; -use crate::arrow::schema_to_arrow_schema; +use crate::arrow::{get_field_id, schema_to_arrow_schema}; use crate::spec::{DataFile, SchemaRef, Struct}; 
use crate::writer::file_writer::{FileWriter, FileWriterBuilder}; use crate::writer::{IcebergWriter, IcebergWriterBuilder}; @@ -53,6 +52,7 @@ pub struct EqualityDeleteWriterConfig { // Projector used to project the data chunk into specific fields. projector: RecordBatchProjector, partition_value: Struct, + project_arrow_schema: ArrowSchemaRef, } impl EqualityDeleteWriterConfig { @@ -62,47 +62,40 @@ impl EqualityDeleteWriterConfig { original_schema: SchemaRef, partition_value: Option, ) -> Result { - let original_arrow_schema = Arc::new(schema_to_arrow_schema(&original_schema)?); + let projected_iceberg_schema = original_schema.project(&equality_ids)?; + // Check invalid field ids, The following rule comes from https://iceberg.apache.org/spec/#identifier-field-ids + // and https://iceberg.apache.org/spec/#equality-delete-files + // - The identifier field ids must be used for primitive types. + // - The identifier field ids must not be used for floating point types or nullable fields. + for field in projected_iceberg_schema.as_struct().fields() { + if !field.field_type.is_primitive() || field.field_type.is_floating_type() { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Field {}(id: {}) is not allowed to be used for equality delete.", + field.name, field.id + ), + )); + } + } + let original_arrow_schema = schema_to_arrow_schema(&original_schema)?; let projector = RecordBatchProjector::new( - original_arrow_schema, - &equality_ids, - // The following rule comes from https://iceberg.apache.org/spec/#identifier-field-ids - // and https://iceberg.apache.org/spec/#equality-delete-files - // - The identifier field ids must be used for primitive types. - // - The identifier field ids must not be used for floating point types or nullable fields. - |field| { - // Only primitive type is allowed to be used for identifier field ids - if field.data_type().is_nested() - || matches!( - field.data_type(), - DataType::Float16 | DataType::Float32 | DataType::Float64 - ) - { - return Ok(None); - } - Ok(Some( - field - .metadata() - .get(PARQUET_FIELD_ID_META_KEY) - .ok_or_else(|| { - Error::new(ErrorKind::Unexpected, "Field metadata is missing.") - })? 
- .parse::() - .map_err(|e| Error::new(ErrorKind::Unexpected, e.to_string()))?, - )) - }, - |_field: &Field| true, + &projected_iceberg_schema, + &original_arrow_schema, + get_field_id, + None, )?; Ok(Self { equality_ids, projector, partition_value: partition_value.unwrap_or(Struct::empty()), + project_arrow_schema: Arc::new(schema_to_arrow_schema(&projected_iceberg_schema)?), }) } /// Return projected Schema pub fn projected_arrow_schema_ref(&self) -> &ArrowSchemaRef { - self.projector.projected_schema_ref() + &self.project_arrow_schema } } @@ -390,7 +383,7 @@ mod test { EqualityDeleteWriterConfig::new(equality_ids, Arc::new(schema), None).unwrap(); let delete_schema = arrow_schema_to_schema(equality_config.projected_arrow_schema_ref()).unwrap(); - let projector = equality_config.projector.clone(); + let mut projector = equality_config.projector.clone(); // prepare writer let pb = ParquetWriterBuilder::new( @@ -741,7 +734,7 @@ mod test { let equality_ids = vec![0_i32, 2, 5]; let equality_config = EqualityDeleteWriterConfig::new(equality_ids, Arc::new(schema), None).unwrap(); - let projector = equality_config.projector.clone(); + let mut projector = equality_config.projector.clone(); // check let to_write_projected = projector.project_batch(to_write)?; diff --git a/crates/integration_tests/tests/shared_tests/scan_all_type.rs b/crates/integration_tests/tests/shared_tests/scan_all_type.rs index 673a78ac03..088e6ac2f1 100644 --- a/crates/integration_tests/tests/shared_tests/scan_all_type.rs +++ b/crates/integration_tests/tests/shared_tests/scan_all_type.rs @@ -357,4 +357,35 @@ async fn test_scan_all_type() { let batches: Vec<_> = batch_stream.try_collect().await.unwrap(); assert_eq!(batches.len(), 1); assert_eq!(batches[0], batch); + + // scan nested field + let batch_stream = table + .scan() + .select(vec!["struct.int", "struct.string"]) + .build() + .unwrap() + .to_arrow() + .await + .unwrap(); + let batches: Vec<_> = batch_stream.try_collect().await.unwrap(); + let expect_batch: RecordBatch = { + let array = + StructArray::from(vec![ + ( + Arc::new(Field::new("int", DataType::Int32, false).with_metadata( + HashMap::from([(PARQUET_FIELD_ID_META_KEY.to_string(), 18.to_string())]), + )), + Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as ArrayRef, + ), + ( + Arc::new(Field::new("string", DataType::Utf8, false).with_metadata( + HashMap::from([(PARQUET_FIELD_ID_META_KEY.to_string(), 19.to_string())]), + )), + Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])) as ArrayRef, + ), + ]); + array.into() + }; + assert_eq!(batches.len(), 1); + assert_eq!(batches[0], expect_batch); }
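
To make the behavior of the new public `Schema::project` API concrete, here is a minimal usage sketch. The schema shape, field names, and ids below are illustrative only and are not taken from this patch's tests; the builder and type constructors are the same ones the tests above use.

```rust
use iceberg::spec::{NestedField, PrimitiveType, Schema, StructType, Type};

fn main() {
    // A schema with one primitive column and one nested struct column.
    let schema = Schema::builder()
        .with_schema_id(1)
        .with_fields(vec![
            NestedField::required(1, "id", Type::Primitive(PrimitiveType::Long)).into(),
            NestedField::required(
                2,
                "point",
                Type::Struct(StructType::new(vec![
                    NestedField::required(3, "x", Type::Primitive(PrimitiveType::Double)).into(),
                    NestedField::required(4, "y", Type::Primitive(PrimitiveType::Double)).into(),
                ])),
            )
            .into(),
        ])
        .build()
        .unwrap();

    // Keep only field ids 1 and 2; the children of `point` travel with it.
    let projected = schema.project(&[1, 2]).unwrap();
    assert_eq!(projected.as_struct().fields().len(), 2);

    // An unknown field id is rejected with `ErrorKind::DataInvalid`.
    assert!(schema.project(&[42]).is_err());
}
```

Inside the crate, such a projected schema is what `RecordBatchTransformer::generate_batch_transform` now feeds to `RecordBatchProjector::new`, together with the source arrow schema, `get_field_id`, and an optional `DefaultValueGenerator` that fills newly added primitive columns.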