
Commit 3c643ec

feat: ArrowReader processes FileScanTasks concurrently
1 parent fa7cf15 commit 3c643ec

1 file changed: +127 −46 lines changed

crates/iceberg/src/arrow/reader.rs

@@ -17,17 +17,26 @@
 
 //! Parquet file data reader
 
-use crate::error::Result;
+use crate::arrow::{arrow_schema_to_schema, get_arrow_datum};
+use crate::expr::visitors::bound_predicate_visitor::{visit, BoundPredicateVisitor};
+use crate::expr::{BoundPredicate, BoundReference};
+use crate::io::{FileIO, FileMetadata, FileRead};
+use crate::scan::FileScanTask;
+use crate::scan::{ArrowRecordBatchStream, FileScanTaskStream};
+use crate::spec::{Datum, SchemaRef};
+use crate::Result;
+use crate::{Error, ErrorKind};
 use arrow_arith::boolean::{and, is_not_null, is_null, not, or};
 use arrow_array::{ArrayRef, BooleanArray, RecordBatch};
 use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq};
 use arrow_schema::{ArrowError, DataType, SchemaRef as ArrowSchemaRef};
-use async_stream::try_stream;
 use bytes::Bytes;
 use fnv::FnvHashSet;
+use futures::channel::mpsc::{channel, Sender};
 use futures::future::BoxFuture;
 use futures::stream::StreamExt;
 use futures::{try_join, TryFutureExt};
+use futures::{SinkExt, TryStreamExt};
 use parquet::arrow::arrow_reader::{ArrowPredicateFn, RowFilter};
 use parquet::arrow::async_reader::{AsyncFileReader, MetadataLoader};
 use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask, PARQUET_FIELD_ID_META_KEY};
@@ -37,22 +46,18 @@ use std::collections::{HashMap, HashSet};
 use std::ops::Range;
 use std::str::FromStr;
 use std::sync::Arc;
+use tokio::spawn;
 
-use crate::arrow::{arrow_schema_to_schema, get_arrow_datum};
-use crate::expr::visitors::bound_predicate_visitor::{visit, BoundPredicateVisitor};
-use crate::expr::{BoundPredicate, BoundReference};
-use crate::io::{FileIO, FileMetadata, FileRead};
-use crate::scan::{ArrowRecordBatchStream, FileScanTaskStream};
-use crate::spec::{Datum, SchemaRef};
-use crate::{Error, ErrorKind};
+const CHANNEL_BUFFER_SIZE: usize = 10;
+const CONCURRENCY_LIMIT_TASKS: usize = 10;
 
 /// Builder to create ArrowReader
 pub struct ArrowReaderBuilder {
     batch_size: Option<usize>,
     field_ids: Vec<usize>,
     file_io: FileIO,
     schema: SchemaRef,
-    predicates: Option<BoundPredicate>,
+    predicate: Option<BoundPredicate>,
 }
 
 impl ArrowReaderBuilder {
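
Note (aside, not part of the diff): the new imports and constants set up a bounded fan-out. A futures mpsc channel with CHANNEL_BUFFER_SIZE slots carries results back to the caller, while a spawned task drives up to CONCURRENCY_LIMIT_TASKS inputs at once. A minimal self-contained sketch of the same pattern, with toy u32 items standing in for FileScanTask and RecordBatch (assumes tokio's "rt" and "macros" features):

    use futures::channel::mpsc::{channel, Sender};
    use futures::{stream, SinkExt, StreamExt, TryStreamExt};

    const CHANNEL_BUFFER_SIZE: usize = 10;
    const CONCURRENCY_LIMIT_TASKS: usize = 10;

    // Stand-in for per-file work: transform the item and send it downstream.
    async fn process(item: u32, mut sender: Sender<Result<u32, String>>) -> Result<(), String> {
        sender.send(Ok(item * 2)).await.map_err(|e| e.to_string())
    }

    #[tokio::main]
    async fn main() {
        let (sender, receiver) = channel(CHANNEL_BUFFER_SIZE);
        let inputs = stream::iter((0..100).map(Ok::<u32, String>));

        // Producer side: up to CONCURRENCY_LIMIT_TASKS items in flight at once.
        tokio::spawn(async move {
            inputs
                .try_for_each_concurrent(CONCURRENCY_LIMIT_TASKS, |item| process(item, sender.clone()))
                .await
        });

        // Consumer side: the Receiver is itself a Stream, which is what
        // read() hands back (boxed) to the caller.
        let results: Vec<_> = receiver.collect().await;
        assert_eq!(results.len(), 100);
    }

Because the channel is bounded, senders park when the consumer falls behind, so the concurrent workers cannot buffer an unbounded number of results.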
@@ -63,7 +68,7 @@ impl ArrowReaderBuilder {
             field_ids: vec![],
             file_io,
             schema,
-            predicates: None,
+            predicate: None,
         }
     }
 
@@ -82,7 +87,7 @@
 
     /// Sets the predicates to apply to the scan.
     pub fn with_predicates(mut self, predicates: BoundPredicate) -> Self {
-        self.predicates = Some(predicates);
+        self.predicate = Some(predicates);
         self
     }
 
@@ -93,7 +98,7 @@
             field_ids: self.field_ids,
             schema: self.schema,
             file_io: self.file_io,
-            predicates: self.predicates,
+            predicate: self.predicate,
         }
     }
 }
@@ -105,64 +110,140 @@ pub struct ArrowReader {
     #[allow(dead_code)]
     schema: SchemaRef,
     file_io: FileIO,
-    predicates: Option<BoundPredicate>,
+    predicate: Option<BoundPredicate>,
 }
 
 impl ArrowReader {
     /// Take a stream of FileScanTasks and reads all the files.
     /// Returns a stream of Arrow RecordBatches containing the data from the files
-    pub fn read(self, mut tasks: FileScanTaskStream) -> crate::Result<ArrowRecordBatchStream> {
-        let file_io = self.file_io.clone();
+    pub fn read(self, tasks: FileScanTaskStream) -> Result<ArrowRecordBatchStream> {
+        let (sender, receiver) = channel(CHANNEL_BUFFER_SIZE);
 
         // Collect Parquet column indices from field ids
         let mut collector = CollectFieldIdVisitor {
             field_ids: HashSet::default(),
         };
-        if let Some(predicates) = &self.predicates {
+        if let Some(predicates) = &self.predicate {
             visit(&mut collector, predicates)?;
         }
 
-        Ok(try_stream! {
-            while let Some(Ok(task)) = tasks.next().await {
-                let parquet_file = file_io
-                    .new_input(task.data().data_file().file_path())?;
-                let (parquet_metadata, parquet_reader) = try_join!(parquet_file.metadata(), parquet_file.reader())?;
-                let arrow_file_reader = ArrowFileReader::new(parquet_metadata, parquet_reader);
+        let tasks = tasks.map(move |task| self.build_file_scan_task_context(task, sender.clone()));
 
-                let mut batch_stream_builder = ParquetRecordBatchStreamBuilder::new(arrow_file_reader)
-                    .await?;
+        spawn(async move {
+            tasks
+                .try_for_each_concurrent(CONCURRENCY_LIMIT_TASKS, Self::process_file_scan_task)
+                .await
+        });
 
-                let parquet_schema = batch_stream_builder.parquet_schema();
-                let arrow_schema = batch_stream_builder.schema();
-                let projection_mask = self.get_arrow_projection_mask(parquet_schema, arrow_schema)?;
-                batch_stream_builder = batch_stream_builder.with_projection(projection_mask);
+        Ok(receiver.boxed())
+    }
 
-                let parquet_schema = batch_stream_builder.parquet_schema();
-                let row_filter = self.get_row_filter(parquet_schema, &collector)?;
+    fn build_file_scan_task_context(
+        &self,
+        task: Result<FileScanTask>,
+        sender: Sender<Result<RecordBatch>>,
+    ) -> Result<FileScanTaskContext> {
+        Ok(FileScanTaskContext::new(
+            task?,
+            self.file_io.clone(),
+            sender,
+            self.batch_size,
+            self.field_ids.clone(),
+            self.schema.clone(),
+            self.predicate.clone(),
+        ))
+    }
+
+    async fn process_file_scan_task(mut context: FileScanTaskContext) -> Result<()> {
+        let file_scan_task = context.take_task();
 
-                if let Some(row_filter) = row_filter {
-                    batch_stream_builder = batch_stream_builder.with_row_filter(row_filter);
-                }
+        // Collect Parquet column indices from field ids
+        let mut collector = CollectFieldIdVisitor {
+            field_ids: HashSet::default(),
+        };
+        if let Some(predicate) = &context.predicate {
+            visit(&mut collector, predicate)?;
+        }
 
-                if let Some(batch_size) = self.batch_size {
-                    batch_stream_builder = batch_stream_builder.with_batch_size(batch_size);
-                }
+        let parquet_file = context
+            .file_io
+            .new_input(file_scan_task.data().data_file().file_path())?;
+        let (parquet_metadata, parquet_reader) =
+            try_join!(parquet_file.metadata(), parquet_file.reader())?;
+        let arrow_file_reader = ArrowFileReader::new(parquet_metadata, parquet_reader);
 
-                let mut batch_stream = batch_stream_builder.build()?;
+        let mut batch_stream_builder =
+            ParquetRecordBatchStreamBuilder::new(arrow_file_reader).await?;
 
-                while let Some(batch) = batch_stream.next().await {
-                    yield batch?;
-                }
-            }
+        let parquet_schema = batch_stream_builder.parquet_schema();
+        let arrow_schema = batch_stream_builder.schema();
+
+        let projection_mask = context.get_arrow_projection_mask(parquet_schema, arrow_schema)?;
+        batch_stream_builder = batch_stream_builder.with_projection(projection_mask);
+
+        let parquet_schema = batch_stream_builder.parquet_schema();
+        let row_filter = context.get_row_filter(parquet_schema, &collector)?;
+
+        if let Some(row_filter) = row_filter {
+            batch_stream_builder = batch_stream_builder.with_row_filter(row_filter);
+        }
+
+        if let Some(batch_size) = context.batch_size {
+            batch_stream_builder = batch_stream_builder.with_batch_size(batch_size);
+        }
+
+        let mut batch_stream = batch_stream_builder.build()?;
+
+        while let Some(batch) = batch_stream.next().await {
+            context.sender.send(Ok(batch?)).await?;
+        }
+
+        Ok(())
+    }
+}
+
+struct FileScanTaskContext {
+    file_scan_task: Option<FileScanTask>,
+    file_io: FileIO,
+    sender: Sender<Result<RecordBatch>>,
+    batch_size: Option<usize>,
+    field_ids: Vec<usize>,
+    schema: SchemaRef,
+    predicate: Option<BoundPredicate>,
+}
+
+impl FileScanTaskContext {
+    fn new(
+        file_scan_task: FileScanTask,
+        file_io: FileIO,
+        sender: Sender<Result<RecordBatch>>,
+        batch_size: Option<usize>,
+        field_ids: Vec<usize>,
+        schema: SchemaRef,
+        predicate: Option<BoundPredicate>,
+    ) -> Self {
+        FileScanTaskContext {
+            file_scan_task: Some(file_scan_task),
+            file_io,
+            sender,
+            batch_size,
+            field_ids,
+            schema,
+            predicate,
         }
-        .boxed())
+    }
+
+    fn take_task(&mut self) -> FileScanTask {
+        let mut result = None;
+        std::mem::swap(&mut self.file_scan_task, &mut result);
+        result.unwrap()
     }
 
     fn get_arrow_projection_mask(
         &self,
         parquet_schema: &SchemaDescriptor,
         arrow_schema: &ArrowSchemaRef,
-    ) -> crate::Result<ProjectionMask> {
+    ) -> Result<ProjectionMask> {
         if self.field_ids.is_empty() {
             Ok(ProjectionMask::all())
         } else {
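
Aside on take_task above: swapping the Option with None and unwrapping is behaviorally the same as Option::take, the usual shorthand. A tiny illustration, with a toy String payload in place of FileScanTask:

    // Equivalent to take_task's std::mem::swap with None:
    // Option::take leaves None behind and hands back the old value.
    fn take_task(slot: &mut Option<String>) -> String {
        slot.take().expect("task already taken")
    }

    fn main() {
        let mut slot = Some(String::from("file-scan-task"));
        assert_eq!(take_task(&mut slot), "file-scan-task");
        assert!(slot.is_none());
        // A second call would panic, matching the unwrap in the diff.
    }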
@@ -232,7 +313,7 @@ impl ArrowReader {
         parquet_schema: &SchemaDescriptor,
         collector: &CollectFieldIdVisitor,
     ) -> Result<Option<RowFilter>> {
-        if let Some(predicates) = &self.predicates {
+        if let Some(predicate) = &self.predicate {
             let field_id_map = build_field_id_map(parquet_schema)?;
 
             // Collect Parquet column indices from field ids.
@@ -255,7 +336,7 @@
             // After collecting required leaf column indices used in the predicate,
             // creates the projection mask for the Arrow predicates.
             let projection_mask = ProjectionMask::leaves(parquet_schema, column_indices.clone());
-            let predicate_func = visit(&mut converter, predicates)?;
+            let predicate_func = visit(&mut converter, predicate)?;
             let arrow_predicate = ArrowPredicateFn::new(projection_mask, predicate_func);
             Ok(Some(RowFilter::new(vec![Box::new(arrow_predicate)])))
         } else {
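
For context, a sketch of how a caller might drain the stream read() now returns. This assumes ArrowRecordBatchStream is a boxed Stream of Result<RecordBatch>, as the receiver.boxed() return value suggests; the drain helper is hypothetical, not part of this commit:

    use futures::StreamExt;

    // Hypothetical consumer: batches arrive through the single channel,
    // produced concurrently across files by the spawned workers.
    async fn drain(mut batches: ArrowRecordBatchStream) -> Result<usize> {
        let mut rows = 0;
        while let Some(batch) = batches.next().await {
            rows += batch?.num_rows(); // each item is a Result<RecordBatch>
        }
        Ok(rows)
    }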
