Skip to content

Commit d87cf00

Browse files
committed
fix: correct usage of visitor
1 parent 1a0e5eb commit d87cf00

File tree

2 files changed

+58
-27
lines changed

2 files changed

+58
-27
lines changed

crates/iceberg/src/arrow/nan_val_cnt_visitor.rs

Lines changed: 54 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
use std::collections::hash_map::Entry;
22
use std::collections::HashMap;
3+
use std::sync::Arc;
34

4-
use arrow_array::{ArrayRef, Float32Array, Float64Array, RecordBatch};
5+
use arrow_array::{ArrayRef, Float32Array, Float64Array, RecordBatch, StructArray};
6+
use arrow_schema::DataType;
57

68
use crate::arrow::ArrowArrayAccessor;
79
use crate::spec::{
8-
visit_schema_with_partner, ListType, MapType, NestedFieldRef, SchemaRef,
9-
PrimitiveType, Schema, SchemaWithPartnerVisitor, StructType,
10+
visit_struct_with_partner, ListType, MapType, NestedFieldRef, PrimitiveType, Schema, SchemaRef,
11+
SchemaWithPartnerVisitor, StructType,
1012
};
1113
use crate::Result;
1214

13-
macro_rules! count_float_nans {
15+
macro_rules! cast_and_update_cnt_map {
1416
($t:ty, $col:ident, $self:ident, $field_id:ident) => {
1517
let nan_val_cnt = $col
1618
.as_any()
@@ -29,10 +31,24 @@ macro_rules! count_float_nans {
2931
v.insert(nan_val_cnt);
3032
}
3133
};
34+
}
35+
}
36+
37+
macro_rules! count_float_nans {
38+
($col:ident, $self:ident, $field_id:ident) => {
39+
match $col.data_type() {
40+
DataType::Float32 => {
41+
cast_and_update_cnt_map!(Float32Array, $col, $self, $field_id);
42+
}
43+
DataType::Float64 => {
44+
cast_and_update_cnt_map!(Float64Array, $col, $self, $field_id);
45+
}
46+
_ => {}
47+
}
3248
};
3349
}
3450

35-
/// TODO(feniljain)
51+
/// Visitor which counts and keeps track of NaN value counts in given record batch(s)
3652
pub struct NanValueCountVisitor {
3753
/// Stores field ID to NaN value count mapping
3854
pub nan_value_counts: HashMap<i32, u64>,
@@ -82,21 +98,31 @@ impl SchemaWithPartnerVisitor<ArrayRef> for NanValueCountVisitor {
8298
Ok(())
8399
}
84100

85-
fn primitive(&mut self, p: &PrimitiveType, col: &ArrayRef) -> Result<Self::T> {
86-
match p {
87-
PrimitiveType::Float => {
88-
// let field_id = p.id;
89-
// TODO(feniljain): fix this
90-
let field_id = 1;
91-
count_float_nans!(Float32Array, col, self, field_id);
92-
}
93-
PrimitiveType::Double => {
94-
let field_id = 1;
95-
count_float_nans!(Float64Array, col, self, field_id);
96-
}
97-
_ => {}
98-
}
101+
fn primitive(&mut self, _p: &PrimitiveType, _col: &ArrayRef) -> Result<Self::T> {
102+
Ok(())
103+
}
99104

105+
fn after_struct_field(&mut self, field: &NestedFieldRef, partner: &ArrayRef) -> Result<()> {
106+
let field_id = field.id;
107+
count_float_nans!(partner, self, field_id);
108+
Ok(())
109+
}
110+
111+
fn after_list_element(&mut self, field: &NestedFieldRef, partner: &ArrayRef) -> Result<()> {
112+
let field_id = field.id;
113+
count_float_nans!(partner, self, field_id);
114+
Ok(())
115+
}
116+
117+
fn after_map_key(&mut self, field: &NestedFieldRef, partner: &ArrayRef) -> Result<()> {
118+
let field_id = field.id;
119+
count_float_nans!(partner, self, field_id);
120+
Ok(())
121+
}
122+
123+
fn after_map_value(&mut self, field: &NestedFieldRef, partner: &ArrayRef) -> Result<()> {
124+
let field_id = field.id;
125+
count_float_nans!(partner, self, field_id);
100126
Ok(())
101127
}
102128
}
@@ -110,14 +136,17 @@ impl NanValueCountVisitor {
110136
}
111137

112138
/// Compute nan value counts in given schema and record batch
113-
pub fn compute(&mut self, schema: SchemaRef, batch: &RecordBatch) -> Result<()> {
114-
let arrow_arr_partner_accessor = ArrowArrayAccessor{};
139+
pub fn compute(&mut self, schema: SchemaRef, batch: RecordBatch) -> Result<()> {
140+
let arrow_arr_partner_accessor = ArrowArrayAccessor {};
115141

116-
for arr_ref in batch.columns() {
117-
visit_schema_with_partner(&schema, arr_ref, self, &arrow_arr_partner_accessor)?;
118-
}
142+
let struct_arr = Arc::new(StructArray::from(batch)) as ArrayRef;
143+
visit_struct_with_partner(
144+
&schema.as_struct(),
145+
&struct_arr,
146+
self,
147+
&arrow_arr_partner_accessor,
148+
)?;
119149

120150
Ok(())
121151
}
122152
}
123-

crates/iceberg/src/writer/file_writer/parquet_writer.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,9 @@ impl FileWriter for ParquetWriter {
525525

526526
self.current_row_num += batch.num_rows();
527527

528-
self.nan_value_count_visitor.compute(self.schema.clone(), batch)?;
528+
// TODO(feniljain): Confirm if this `clone` is okay to perform
529+
let batch_c = batch.clone();
530+
self.nan_value_count_visitor.compute(self.schema.clone(), batch_c)?;
529531

530532
// Lazy initialize the writer
531533
let writer = if let Some(writer) = &mut self.inner_writer {
@@ -1633,7 +1635,7 @@ mod tests {
16331635
MockLocationGenerator::new(temp_dir.path().to_str().unwrap().to_string());
16341636
let file_name_gen =
16351637
DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet);
1636-
//
1638+
16371639
// prepare data
16381640
let arrow_schema = {
16391641
let fields = vec![

0 commit comments

Comments
 (0)