Skip to content

Commit e6abbc5

Browse files
committed
refine RecordBatchProjector
1 parent ba8dca2 commit e6abbc5

File tree

2 files changed

+213
-105
lines changed

2 files changed

+213
-105
lines changed

crates/iceberg/src/arrow/record_batch_projector.rs

Lines changed: 92 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef};
2323
use crate::error::Result;
2424
use crate::{Error, ErrorKind};
2525

26-
/// Help to project specific field from `RecordBatch`` according to the fields id of meta of field.
26+
/// Help to project specific field from `RecordBatch`` according to the fields id.
2727
#[derive(Clone)]
28-
pub struct RecordBatchProjector {
28+
pub(crate) struct RecordBatchProjector {
2929
// A vector of vectors, where each inner vector represents the index path to access a specific field in a nested structure.
3030
// E.g. [[0], [1, 2]] means the first field is accessed directly from the first column,
3131
// while the second field is accessed from the second column and then from its third subcolumn (second column must be a struct column).
@@ -36,35 +36,35 @@ pub struct RecordBatchProjector {
3636

3737
impl RecordBatchProjector {
3838
/// Init ArrowFieldProjector
39-
pub fn new<F>(
39+
///
40+
/// This function will iterate through the field and fetch the field from the original schema according to the field ids.
41+
/// The function to fetch the field id from the field is provided by `field_id_fetch_func`, return None if the field need to be skipped.
42+
/// This function will iterate through the nested fields if the field is a struct, `searchable_field_func` can be used to control whether
43+
/// iterate into the nested fields.
44+
pub(crate) fn new<F1, F2>(
4045
original_schema: SchemaRef,
4146
field_ids: &[i32],
42-
field_id_fetch_func: F,
47+
field_id_fetch_func: F1,
48+
searchable_field_func: F2,
4349
) -> Result<Self>
4450
where
45-
F: Fn(&Field) -> Option<i64>,
51+
F1: Fn(&Field) -> Result<Option<i64>>,
52+
F2: Fn(&Field) -> bool,
4653
{
4754
let mut field_indices = Vec::with_capacity(field_ids.len());
4855
let mut fields = Vec::with_capacity(field_ids.len());
4956
for &id in field_ids {
5057
let mut field_index = vec![];
51-
if let Ok(field) = Self::fetch_field_index(
58+
let field = Self::fetch_field_index(
5259
original_schema.fields(),
5360
&mut field_index,
5461
id as i64,
5562
&field_id_fetch_func,
56-
) {
57-
fields.push(field.clone());
58-
field_indices.push(field_index);
59-
} else {
60-
return Err(Error::new(
61-
ErrorKind::DataInvalid,
62-
format!(
63-
"Can't find source column id or column data type invalid: {}",
64-
id
65-
),
66-
));
67-
}
63+
&searchable_field_func,
64+
)?
65+
.ok_or_else(|| Error::new(ErrorKind::Unexpected, "Field not found"))?;
66+
fields.push(field.clone());
67+
field_indices.push(field_index);
6868
}
6969
let delete_arrow_schema = Arc::new(Schema::new(fields));
7070
Ok(Self {
@@ -73,59 +73,50 @@ impl RecordBatchProjector {
7373
})
7474
}
7575

76-
fn fetch_field_index<F>(
76+
fn fetch_field_index<F1, F2>(
7777
fields: &Fields,
7878
index_vec: &mut Vec<usize>,
7979
target_field_id: i64,
80-
field_id_fetch_func: &F,
81-
) -> Result<FieldRef>
80+
field_id_fetch_func: &F1,
81+
searchable_field_func: &F2,
82+
) -> Result<Option<FieldRef>>
8283
where
83-
F: Fn(&Field) -> Option<i64>,
84+
F1: Fn(&Field) -> Result<Option<i64>>,
85+
F2: Fn(&Field) -> bool,
8486
{
8587
for (pos, field) in fields.iter().enumerate() {
86-
match field.data_type() {
87-
DataType::Float16 | DataType::Float32 | DataType::Float64 => {
88-
return Err(Error::new(
89-
ErrorKind::DataInvalid,
90-
"Delete column data type cannot be float or double",
91-
));
88+
let id = field_id_fetch_func(field)?;
89+
if let Some(id) = id {
90+
if target_field_id == id {
91+
index_vec.push(pos);
92+
return Ok(Some(field.clone()));
9293
}
93-
_ => {
94-
let id = field_id_fetch_func(field).ok_or_else(|| {
95-
Error::new(ErrorKind::DataInvalid, "column_id must be parsable as i64")
96-
})?;
97-
if target_field_id == id {
94+
}
95+
if let DataType::Struct(inner) = field.data_type() {
96+
if searchable_field_func(field) {
97+
if let Some(res) = Self::fetch_field_index(
98+
inner,
99+
index_vec,
100+
target_field_id,
101+
field_id_fetch_func,
102+
searchable_field_func,
103+
)? {
98104
index_vec.push(pos);
99-
return Ok(field.clone());
100-
}
101-
if let DataType::Struct(inner) = field.data_type() {
102-
let res = Self::fetch_field_index(
103-
inner,
104-
index_vec,
105-
target_field_id,
106-
field_id_fetch_func,
107-
);
108-
if !index_vec.is_empty() {
109-
index_vec.push(pos);
110-
return res;
111-
}
105+
return Ok(Some(res));
112106
}
113107
}
114108
}
115109
}
116-
Err(Error::new(
117-
ErrorKind::DataInvalid,
118-
"Column id not found in fields",
119-
))
110+
Ok(None)
120111
}
121112

122113
/// Return the reference of projected schema
123-
pub fn projected_schema_ref(&self) -> &SchemaRef {
114+
pub(crate) fn projected_schema_ref(&self) -> &SchemaRef {
124115
&self.projected_schema
125116
}
126117

127118
/// Do projection with record batch
128-
pub fn project_bacth(&self, batch: RecordBatch) -> Result<RecordBatch> {
119+
pub(crate) fn project_bacth(&self, batch: RecordBatch) -> Result<RecordBatch> {
129120
RecordBatch::try_new(
130121
self.projected_schema.clone(),
131122
self.project_column(batch.columns())?,
@@ -134,7 +125,7 @@ impl RecordBatchProjector {
134125
}
135126

136127
/// Do projection with columns
137-
pub fn project_column(&self, batch: &[ArrayRef]) -> Result<Vec<ArrayRef>> {
128+
pub(crate) fn project_column(&self, batch: &[ArrayRef]) -> Result<Vec<ArrayRef>> {
138129
self.field_indices
139130
.iter()
140131
.map(|index_vec| Self::get_column_by_field_index(batch, index_vec))
@@ -167,6 +158,7 @@ mod test {
167158
use arrow_schema::{DataType, Field, Fields, Schema};
168159

169160
use crate::arrow::record_batch_projector::RecordBatchProjector;
161+
use crate::{Error, ErrorKind};
170162

171163
#[test]
172164
fn test_record_batch_projector_nested_level() {
@@ -185,14 +177,15 @@ mod test {
185177
let schema = Arc::new(Schema::new(fields));
186178

187179
let field_id_fetch_func = |field: &Field| match field.name().as_str() {
188-
"field1" => Some(1),
189-
"field2" => Some(2),
190-
"inner_field1" => Some(3),
191-
"inner_field2" => Some(4),
192-
_ => None,
180+
"field1" => Ok(Some(1)),
181+
"field2" => Ok(Some(2)),
182+
"inner_field1" => Ok(Some(3)),
183+
"inner_field2" => Ok(Some(4)),
184+
_ => Err(Error::new(ErrorKind::Unexpected, "Field id not found")),
193185
};
194186
let projector =
195-
RecordBatchProjector::new(schema.clone(), &[1, 3], field_id_fetch_func).unwrap();
187+
RecordBatchProjector::new(schema.clone(), &[1, 3], field_id_fetch_func, |_| true)
188+
.unwrap();
196189

197190
assert!(projector.field_indices.len() == 2);
198191
assert_eq!(projector.field_indices[0], vec![0]);
@@ -248,14 +241,48 @@ mod test {
248241
let schema = Arc::new(Schema::new(fields));
249242

250243
let field_id_fetch_func = |field: &Field| match field.name().as_str() {
251-
"field1" => Some(1),
252-
"field2" => Some(2),
253-
"inner_field1" => Some(3),
254-
"inner_field2" => Some(4),
255-
_ => None,
244+
"field1" => Ok(Some(1)),
245+
"field2" => Ok(Some(2)),
246+
"inner_field1" => Ok(Some(3)),
247+
"inner_field2" => Ok(Some(4)),
248+
_ => Err(Error::new(ErrorKind::Unexpected, "Field id not found")),
256249
};
257-
let projector = RecordBatchProjector::new(schema.clone(), &[1, 5], field_id_fetch_func);
250+
let projector =
251+
RecordBatchProjector::new(schema.clone(), &[1, 5], field_id_fetch_func, |_| true);
252+
253+
assert!(projector.is_err());
254+
}
255+
256+
#[test]
257+
fn test_field_not_reachable() {
258+
let inner_fields = vec![
259+
Field::new("inner_field1", DataType::Int32, false),
260+
Field::new("inner_field2", DataType::Utf8, false),
261+
];
262+
263+
let fields = vec![
264+
Field::new("field1", DataType::Int32, false),
265+
Field::new(
266+
"field2",
267+
DataType::Struct(Fields::from(inner_fields.clone())),
268+
false,
269+
),
270+
];
271+
let schema = Arc::new(Schema::new(fields));
258272

273+
let field_id_fetch_func = |field: &Field| match field.name().as_str() {
274+
"field1" => Ok(Some(1)),
275+
"field2" => Ok(Some(2)),
276+
"inner_field1" => Ok(Some(3)),
277+
"inner_field2" => Ok(Some(4)),
278+
_ => Err(Error::new(ErrorKind::Unexpected, "Field id not found")),
279+
};
280+
let projector =
281+
RecordBatchProjector::new(schema.clone(), &[3], field_id_fetch_func, |_| false);
259282
assert!(projector.is_err());
283+
284+
let projector =
285+
RecordBatchProjector::new(schema.clone(), &[3], field_id_fetch_func, |_| true);
286+
assert!(projector.is_ok());
260287
}
261288
}

0 commit comments

Comments
 (0)