@@ -23,9 +23,9 @@ use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef};
2323use crate :: error:: Result ;
2424use crate :: { Error , ErrorKind } ;
2525
26- /// Help to project specific field from `RecordBatch`` according to the fields id of meta of field .
26+ /// Helps project specific fields from a `RecordBatch` according to their field ids.
2727#[ derive( Clone ) ]
28- pub struct RecordBatchProjector {
28+ pub ( crate ) struct RecordBatchProjector {
2929 // A vector of vectors, where each inner vector represents the index path to access a specific field in a nested structure.
3030 // E.g. [[0], [1, 2]] means the first field is accessed directly from the first column,
3131 // while the second field is accessed from the second column and then from its third subcolumn (second column must be a struct column).
@@ -36,35 +36,35 @@ pub struct RecordBatchProjector {
3636
3737impl RecordBatchProjector {
3838 /// Init ArrowFieldProjector
39- pub fn new < F > (
39+ ///
40+ /// This function iterates through the fields of the original schema and fetches the field matching each of the given field ids.
41+ /// The field id of a field is obtained via `field_id_fetch_func`, which returns `Ok(None)` for a field that should not be matched.
42+ /// If a field is a struct, this function also searches its nested fields; `searchable_field_func` controls whether
43+ /// the search descends into the nested fields of a given struct field.
44+ pub ( crate ) fn new < F1 , F2 > (
4045 original_schema : SchemaRef ,
4146 field_ids : & [ i32 ] ,
42- field_id_fetch_func : F ,
47+ field_id_fetch_func : F1 ,
48+ searchable_field_func : F2 ,
4349 ) -> Result < Self >
4450 where
45- F : Fn ( & Field ) -> Option < i64 > ,
51+ F1 : Fn ( & Field ) -> Result < Option < i64 > > ,
52+ F2 : Fn ( & Field ) -> bool ,
4653 {
4754 let mut field_indices = Vec :: with_capacity ( field_ids. len ( ) ) ;
4855 let mut fields = Vec :: with_capacity ( field_ids. len ( ) ) ;
4956 for & id in field_ids {
5057 let mut field_index = vec ! [ ] ;
51- if let Ok ( field) = Self :: fetch_field_index (
58+ let field = Self :: fetch_field_index (
5259 original_schema. fields ( ) ,
5360 & mut field_index,
5461 id as i64 ,
5562 & field_id_fetch_func,
56- ) {
57- fields. push ( field. clone ( ) ) ;
58- field_indices. push ( field_index) ;
59- } else {
60- return Err ( Error :: new (
61- ErrorKind :: DataInvalid ,
62- format ! (
63- "Can't find source column id or column data type invalid: {}" ,
64- id
65- ) ,
66- ) ) ;
67- }
63+ & searchable_field_func,
64+ ) ?
65+ . ok_or_else ( || Error :: new ( ErrorKind :: Unexpected , "Field not found" ) ) ?;
66+ fields. push ( field. clone ( ) ) ;
67+ field_indices. push ( field_index) ;
6868 }
6969 let delete_arrow_schema = Arc :: new ( Schema :: new ( fields) ) ;
7070 Ok ( Self {
@@ -73,59 +73,50 @@ impl RecordBatchProjector {
7373 } )
7474 }
7575
76- fn fetch_field_index < F > (
76+ fn fetch_field_index < F1 , F2 > (
7777 fields : & Fields ,
7878 index_vec : & mut Vec < usize > ,
7979 target_field_id : i64 ,
80- field_id_fetch_func : & F ,
81- ) -> Result < FieldRef >
80+ field_id_fetch_func : & F1 ,
81+ searchable_field_func : & F2 ,
82+ ) -> Result < Option < FieldRef > >
8283 where
83- F : Fn ( & Field ) -> Option < i64 > ,
84+ F1 : Fn ( & Field ) -> Result < Option < i64 > > ,
85+ F2 : Fn ( & Field ) -> bool ,
8486 {
8587 for ( pos, field) in fields. iter ( ) . enumerate ( ) {
86- match field. data_type ( ) {
87- DataType :: Float16 | DataType :: Float32 | DataType :: Float64 => {
88- return Err ( Error :: new (
89- ErrorKind :: DataInvalid ,
90- "Delete column data type cannot be float or double" ,
91- ) ) ;
88+ let id = field_id_fetch_func ( field) ?;
89+ if let Some ( id) = id {
90+ if target_field_id == id {
91+ index_vec. push ( pos) ;
92+ return Ok ( Some ( field. clone ( ) ) ) ;
9293 }
93- _ => {
94- let id = field_id_fetch_func ( field) . ok_or_else ( || {
95- Error :: new ( ErrorKind :: DataInvalid , "column_id must be parsable as i64" )
96- } ) ?;
97- if target_field_id == id {
94+ }
95+ if let DataType :: Struct ( inner) = field. data_type ( ) {
96+ if searchable_field_func ( field) {
97+ if let Some ( res) = Self :: fetch_field_index (
98+ inner,
99+ index_vec,
100+ target_field_id,
101+ field_id_fetch_func,
102+ searchable_field_func,
103+ ) ? {
98104 index_vec. push ( pos) ;
99- return Ok ( field. clone ( ) ) ;
100- }
101- if let DataType :: Struct ( inner) = field. data_type ( ) {
102- let res = Self :: fetch_field_index (
103- inner,
104- index_vec,
105- target_field_id,
106- field_id_fetch_func,
107- ) ;
108- if !index_vec. is_empty ( ) {
109- index_vec. push ( pos) ;
110- return res;
111- }
105+ return Ok ( Some ( res) ) ;
112106 }
113107 }
114108 }
115109 }
116- Err ( Error :: new (
117- ErrorKind :: DataInvalid ,
118- "Column id not found in fields" ,
119- ) )
110+ Ok ( None )
120111 }
121112
122113 /// Return the reference of projected schema
123- pub fn projected_schema_ref ( & self ) -> & SchemaRef {
114+ pub ( crate ) fn projected_schema_ref ( & self ) -> & SchemaRef {
124115 & self . projected_schema
125116 }
126117
127118 /// Do projection with record batch
128- pub fn project_bacth ( & self , batch : RecordBatch ) -> Result < RecordBatch > {
119+ pub ( crate ) fn project_bacth ( & self , batch : RecordBatch ) -> Result < RecordBatch > {
129120 RecordBatch :: try_new (
130121 self . projected_schema . clone ( ) ,
131122 self . project_column ( batch. columns ( ) ) ?,
@@ -134,7 +125,7 @@ impl RecordBatchProjector {
134125 }
135126
136127 /// Do projection with columns
137- pub fn project_column ( & self , batch : & [ ArrayRef ] ) -> Result < Vec < ArrayRef > > {
128+ pub ( crate ) fn project_column ( & self , batch : & [ ArrayRef ] ) -> Result < Vec < ArrayRef > > {
138129 self . field_indices
139130 . iter ( )
140131 . map ( |index_vec| Self :: get_column_by_field_index ( batch, index_vec) )
@@ -167,6 +158,7 @@ mod test {
167158 use arrow_schema:: { DataType , Field , Fields , Schema } ;
168159
169160 use crate :: arrow:: record_batch_projector:: RecordBatchProjector ;
161+ use crate :: { Error , ErrorKind } ;
170162
171163 #[ test]
172164 fn test_record_batch_projector_nested_level ( ) {
@@ -185,14 +177,15 @@ mod test {
185177 let schema = Arc :: new ( Schema :: new ( fields) ) ;
186178
187179 let field_id_fetch_func = |field : & Field | match field. name ( ) . as_str ( ) {
188- "field1" => Some ( 1 ) ,
189- "field2" => Some ( 2 ) ,
190- "inner_field1" => Some ( 3 ) ,
191- "inner_field2" => Some ( 4 ) ,
192- _ => None ,
180+ "field1" => Ok ( Some ( 1 ) ) ,
181+ "field2" => Ok ( Some ( 2 ) ) ,
182+ "inner_field1" => Ok ( Some ( 3 ) ) ,
183+ "inner_field2" => Ok ( Some ( 4 ) ) ,
184+ _ => Err ( Error :: new ( ErrorKind :: Unexpected , "Field id not found" ) ) ,
193185 } ;
194186 let projector =
195- RecordBatchProjector :: new ( schema. clone ( ) , & [ 1 , 3 ] , field_id_fetch_func) . unwrap ( ) ;
187+ RecordBatchProjector :: new ( schema. clone ( ) , & [ 1 , 3 ] , field_id_fetch_func, |_| true )
188+ . unwrap ( ) ;
196189
197190 assert ! ( projector. field_indices. len( ) == 2 ) ;
198191 assert_eq ! ( projector. field_indices[ 0 ] , vec![ 0 ] ) ;
@@ -248,14 +241,48 @@ mod test {
248241 let schema = Arc :: new ( Schema :: new ( fields) ) ;
249242
250243 let field_id_fetch_func = |field : & Field | match field. name ( ) . as_str ( ) {
251- "field1" => Some ( 1 ) ,
252- "field2" => Some ( 2 ) ,
253- "inner_field1" => Some ( 3 ) ,
254- "inner_field2" => Some ( 4 ) ,
255- _ => None ,
244+ "field1" => Ok ( Some ( 1 ) ) ,
245+ "field2" => Ok ( Some ( 2 ) ) ,
246+ "inner_field1" => Ok ( Some ( 3 ) ) ,
247+ "inner_field2" => Ok ( Some ( 4 ) ) ,
248+ _ => Err ( Error :: new ( ErrorKind :: Unexpected , "Field id not found" ) ) ,
256249 } ;
257- let projector = RecordBatchProjector :: new ( schema. clone ( ) , & [ 1 , 5 ] , field_id_fetch_func) ;
250+ let projector =
251+ RecordBatchProjector :: new ( schema. clone ( ) , & [ 1 , 5 ] , field_id_fetch_func, |_| true ) ;
252+
253+ assert ! ( projector. is_err( ) ) ;
254+ }
255+
256+ #[ test]
257+ fn test_field_not_reachable ( ) {
258+ let inner_fields = vec ! [
259+ Field :: new( "inner_field1" , DataType :: Int32 , false ) ,
260+ Field :: new( "inner_field2" , DataType :: Utf8 , false ) ,
261+ ] ;
262+
263+ let fields = vec ! [
264+ Field :: new( "field1" , DataType :: Int32 , false ) ,
265+ Field :: new(
266+ "field2" ,
267+ DataType :: Struct ( Fields :: from( inner_fields. clone( ) ) ) ,
268+ false ,
269+ ) ,
270+ ] ;
271+ let schema = Arc :: new ( Schema :: new ( fields) ) ;
258272
273+ let field_id_fetch_func = |field : & Field | match field. name ( ) . as_str ( ) {
274+ "field1" => Ok ( Some ( 1 ) ) ,
275+ "field2" => Ok ( Some ( 2 ) ) ,
276+ "inner_field1" => Ok ( Some ( 3 ) ) ,
277+ "inner_field2" => Ok ( Some ( 4 ) ) ,
278+ _ => Err ( Error :: new ( ErrorKind :: Unexpected , "Field id not found" ) ) ,
279+ } ;
280+ let projector =
281+ RecordBatchProjector :: new ( schema. clone ( ) , & [ 3 ] , field_id_fetch_func, |_| false ) ;
259282 assert ! ( projector. is_err( ) ) ;
283+
284+ let projector =
285+ RecordBatchProjector :: new ( schema. clone ( ) , & [ 3 ] , field_id_fetch_func, |_| true ) ;
286+ assert ! ( projector. is_ok( ) ) ;
260287 }
261288}
0 commit comments