Feature Request / Improvement
Creating this issue from #614 (comment), as suggested.
I have a couple of moderately heavy tables with millions of Parquet files and ~100 manifest files, which I routinely need to query through the .files metatable.
The existing implementation iterates over the manifests sequentially and takes quite a while to construct the resulting pyarrow.Table.
Since reporting the issue in #614, I have spent a while thinking about it and came up with a temporary solution that works in my case.
I'm using a ProcessPoolExecutor to distribute the manifests one per worker, returning a pyarrow.RecordBatch from each, and then assembling the resulting table with pyarrow.Table.from_batches().
Ideally I would like to process the Avro manifests on a per-block basis to speed things up further, but doing so with the ProcessPoolExecutor seems to come with too much overhead.
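To illustrate the per-block idea, here is a minimal sketch of reading a manifest block by block, assuming fastavro is installed and the manifest path is reachable through the table's FileIO; iter_manifest_blocks is a made-up helper, not pyiceberg API:

    from typing import Iterator, List

    from fastavro import block_reader
    from pyiceberg.io import FileIO


    def iter_manifest_blocks(io: FileIO, manifest_path: str) -> Iterator[List[dict]]:
        # Hypothetical helper: yields the raw records of one Avro block at a time.
        with io.new_input(manifest_path).open() as f:
            for block in block_reader(f):
                # Each block could in principle become its own unit of work, but
                # the decoded records have to be pickled to cross the process
                # boundary, which is where the overhead shows up.
                yield list(block)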
Before writing my custom code, .files() took ~70 seconds on a table with 350k files and ~200 seconds on a table with 1M files. The code below, which is not an apples-to-apples comparison because I only process the data I need, takes less than 5 seconds.
Here's what I did in my particular case, to give the general idea:
    from concurrent.futures import ProcessPoolExecutor
    from functools import partial

    import pyarrow as pa
    import pyiceberg.table
    from pyiceberg.conversions import from_bytes
    from pyiceberg.io import FileIO
    from pyiceberg.manifest import FileFormat, ManifestEntryStatus, ManifestFile
    from pyiceberg.schema import Schema

    _FILES_SCHEMA = pa.schema(
        [
            pa.field('path', pa.string(), nullable=False),
            pa.field('event_id', pa.string(), nullable=False),
            pa.field('ts_min', pa.int64(), nullable=False),
            pa.field('ts_max', pa.int64(), nullable=False),
        ]
    )


    def get_message_table_files(
        table: pyiceberg.table.Table,
    ) -> pa.Table:
        schema = table.metadata.schema()
        snapshot = table.current_snapshot()
        if not snapshot:
            return pa.Table.from_pylist(
                [],
                schema=_FILES_SCHEMA,
            )
        # One manifest per worker process; each returns a RecordBatch that is
        # stitched into the final table.
        with ProcessPoolExecutor() as pool:
            return pa.Table.from_batches(
                pool.map(
                    partial(_process_manifest, schema, table.io),
                    snapshot.manifests(table.io),
                ),
                schema=_FILES_SCHEMA,
            )


    def _process_manifest(
        table_schema: Schema,
        io: FileIO,
        manifest: ManifestFile,
    ) -> pa.RecordBatch:
        ts_field = table_schema.find_field('ts')
        rows = [
            dict(
                path=entry.data_file.file_path,
                event_id=entry.data_file.partition.event_id,
                # Decode the serialized column bounds for the 'ts' field.
                ts_min=from_bytes(
                    ts_field.field_type,
                    entry.data_file.lower_bounds.get(ts_field.field_id),
                ),
                ts_max=from_bytes(
                    ts_field.field_type,
                    entry.data_file.upper_bounds.get(ts_field.field_id),
                ),
            )
            for entry in manifest.fetch_manifest_entry(io)
            if entry.data_file.file_format == FileFormat.PARQUET
            and entry.status != ManifestEntryStatus.DELETED
        ]
        return pa.RecordBatch.from_pylist(
            rows,
            schema=_FILES_SCHEMA,
        )
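For completeness, here's a rough sketch of how the two paths could be timed against each other. It assumes a hypothetical, already-loaded table handle tbl whose schema has a ts column and an event_id partition field (as the code above expects), and a pyiceberg version that exposes the built-in metadata table as tbl.inspect.files(); as noted, the comparison is not apples to apples.

    import time

    # 'tbl' is a hypothetical, already-loaded pyiceberg table handle.
    start = time.perf_counter()
    built_in = tbl.inspect.files()  # built-in sequential implementation
    print('tbl.inspect.files():', round(time.perf_counter() - start, 1), 's')

    start = time.perf_counter()
    custom = get_message_table_files(tbl)  # parallel version from this issue
    print('get_message_table_files():', round(time.perf_counter() - start, 1), 's')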