11use std:: collections:: hash_map:: Entry ;
22use std:: collections:: HashMap ;
3+ use std:: sync:: Arc ;
34
4- use arrow_array:: { ArrayRef , Float32Array , Float64Array , RecordBatch } ;
5+ use arrow_array:: { ArrayRef , Float32Array , Float64Array , RecordBatch , StructArray } ;
6+ use arrow_schema:: DataType ;
57
68use crate :: arrow:: ArrowArrayAccessor ;
79use crate :: spec:: {
8- visit_schema_with_partner , ListType , MapType , NestedFieldRef , SchemaRef ,
9- PrimitiveType , Schema , SchemaWithPartnerVisitor , StructType ,
10+ visit_struct_with_partner , ListType , MapType , NestedFieldRef , PrimitiveType , Schema , SchemaRef ,
11+ SchemaWithPartnerVisitor , StructType ,
1012} ;
1113use crate :: Result ;
1214
13- macro_rules! count_float_nans {
15+ macro_rules! cast_and_update_cnt_map {
1416 ( $t: ty, $col: ident, $self: ident, $field_id: ident) => {
1517 let nan_val_cnt = $col
1618 . as_any( )
@@ -29,10 +31,24 @@ macro_rules! count_float_nans {
2931 v. insert( nan_val_cnt) ;
3032 }
3133 } ;
34+ }
35+ }
36+
37+ macro_rules! count_float_nans {
38+ ( $col: ident, $self: ident, $field_id: ident) => {
39+ match $col. data_type( ) {
40+ DataType :: Float32 => {
41+ cast_and_update_cnt_map!( Float32Array , $col, $self, $field_id) ;
42+ }
43+ DataType :: Float64 => {
44+ cast_and_update_cnt_map!( Float64Array , $col, $self, $field_id) ;
45+ }
46+ _ => { }
47+ }
3248 } ;
3349}
3450
35- /// TODO(feniljain )
51+ /// Visitor which counts and keeps track of NaN value counts in given record batch(s )
3652pub struct NanValueCountVisitor {
3753 /// Stores field ID to NaN value count mapping
3854 pub nan_value_counts : HashMap < i32 , u64 > ,
@@ -82,21 +98,31 @@ impl SchemaWithPartnerVisitor<ArrayRef> for NanValueCountVisitor {
8298 Ok ( ( ) )
8399 }
84100
85- fn primitive ( & mut self , p : & PrimitiveType , col : & ArrayRef ) -> Result < Self :: T > {
86- match p {
87- PrimitiveType :: Float => {
88- // let field_id = p.id;
89- // TODO(feniljain): fix this
90- let field_id = 1 ;
91- count_float_nans ! ( Float32Array , col, self , field_id) ;
92- }
93- PrimitiveType :: Double => {
94- let field_id = 1 ;
95- count_float_nans ! ( Float64Array , col, self , field_id) ;
96- }
97- _ => { }
98- }
101+ fn primitive ( & mut self , _p : & PrimitiveType , _col : & ArrayRef ) -> Result < Self :: T > {
102+ Ok ( ( ) )
103+ }
99104
105+ fn after_struct_field ( & mut self , field : & NestedFieldRef , partner : & ArrayRef ) -> Result < ( ) > {
106+ let field_id = field. id ;
107+ count_float_nans ! ( partner, self , field_id) ;
108+ Ok ( ( ) )
109+ }
110+
111+ fn after_list_element ( & mut self , field : & NestedFieldRef , partner : & ArrayRef ) -> Result < ( ) > {
112+ let field_id = field. id ;
113+ count_float_nans ! ( partner, self , field_id) ;
114+ Ok ( ( ) )
115+ }
116+
117+ fn after_map_key ( & mut self , field : & NestedFieldRef , partner : & ArrayRef ) -> Result < ( ) > {
118+ let field_id = field. id ;
119+ count_float_nans ! ( partner, self , field_id) ;
120+ Ok ( ( ) )
121+ }
122+
123+ fn after_map_value ( & mut self , field : & NestedFieldRef , partner : & ArrayRef ) -> Result < ( ) > {
124+ let field_id = field. id ;
125+ count_float_nans ! ( partner, self , field_id) ;
100126 Ok ( ( ) )
101127 }
102128}
@@ -110,14 +136,17 @@ impl NanValueCountVisitor {
110136 }
111137
112138 /// Compute nan value counts in given schema and record batch
113- pub fn compute ( & mut self , schema : SchemaRef , batch : & RecordBatch ) -> Result < ( ) > {
114- let arrow_arr_partner_accessor = ArrowArrayAccessor { } ;
139+ pub fn compute ( & mut self , schema : SchemaRef , batch : RecordBatch ) -> Result < ( ) > {
140+ let arrow_arr_partner_accessor = ArrowArrayAccessor { } ;
115141
116- for arr_ref in batch. columns ( ) {
117- visit_schema_with_partner ( & schema, arr_ref, self , & arrow_arr_partner_accessor) ?;
118- }
142+ let struct_arr = Arc :: new ( StructArray :: from ( batch) ) as ArrayRef ;
143+ visit_struct_with_partner (
144+ & schema. as_struct ( ) ,
145+ & struct_arr,
146+ self ,
147+ & arrow_arr_partner_accessor,
148+ ) ?;
119149
120150 Ok ( ( ) )
121151 }
122152}
123-
0 commit comments