Commit d03f058

partitioned append with identity transform
1 parent a750b4b commit d03f058

File tree

4 files changed (+623, -65 lines)

Makefile

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ test-integration:
 	docker-compose -f dev/docker-compose-integration.yml up -d
 	sleep 10
 	docker-compose -f dev/docker-compose-integration.yml exec -T spark-iceberg ipython ./provision.py
-	poetry run pytest tests/ -v -m integration ${PYTEST_ARGS}
+	poetry run pytest tests/ -v -m integration ${PYTEST_ARGS} -s
 
 test-integration-rebuild:
 	docker-compose -f dev/docker-compose-integration.yml kill

pyiceberg/io/pyarrow.py

Lines changed: 33 additions & 43 deletions
@@ -126,7 +126,7 @@
 from pyiceberg.table import WriteTask
 from pyiceberg.table.name_mapping import NameMapping
 from pyiceberg.transforms import TruncateTransform
-from pyiceberg.typedef import EMPTY_DICT, Properties, Record
+from pyiceberg.typedef import EMPTY_DICT, Properties
 from pyiceberg.types import (
     BinaryType,
     BooleanType,
@@ -1686,15 +1686,13 @@ def fill_parquet_file_metadata(
 
     lower_bounds = {}
     upper_bounds = {}
-
     for k, agg in col_aggs.items():
         _min = agg.min_as_bytes()
         if _min is not None:
             lower_bounds[k] = _min
         _max = agg.max_as_bytes()
         if _max is not None:
             upper_bounds[k] = _max
-
     for field_id in invalidate_col:
         del lower_bounds[field_id]
         del upper_bounds[field_id]
@@ -1711,45 +1709,37 @@ def fill_parquet_file_metadata(
 
 
 def write_file(table: Table, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
-    task = next(tasks)
-
-    try:
-        _ = next(tasks)
-        # If there are more tasks, raise an exception
-        raise NotImplementedError("Only unpartitioned writes are supported: https://github.com/apache/iceberg-python/issues/208")
-    except StopIteration:
-        pass
-
-    file_path = f'{table.location()}/data/{task.generate_data_file_filename("parquet")}'
-    file_schema = schema_to_pyarrow(table.schema())
-
-    collected_metrics: List[pq.FileMetaData] = []
-    fo = table.io.new_output(file_path)
-    with fo.create(overwrite=True) as fos:
-        with pq.ParquetWriter(fos, schema=file_schema, version="1.0", metadata_collector=collected_metrics) as writer:
-            writer.write_table(task.df)
-
-    data_file = DataFile(
-        content=DataFileContent.DATA,
-        file_path=file_path,
-        file_format=FileFormat.PARQUET,
-        partition=Record(),
-        file_size_in_bytes=len(fo),
-        sort_order_id=task.sort_order_id,
-        # Just copy these from the table for now
-        spec_id=table.spec().spec_id,
-        equality_ids=None,
-        key_metadata=None,
-    )
+    for task in tasks:
+        file_path = f'{table.location()}/data/{task.generate_data_file_path("parquet")}'
+        file_schema = schema_to_pyarrow(table.schema())
+
+        collected_metrics: List[pq.FileMetaData] = []
+        fo = table.io.new_output(file_path)
+        with fo.create(overwrite=True) as fos:
+            with pq.ParquetWriter(fos, schema=file_schema, version="1.0", metadata_collector=collected_metrics) as writer:
+                writer.write_table(task.df)
+
+        data_file = DataFile(
+            content=DataFileContent.DATA,
+            file_path=file_path,
+            file_format=FileFormat.PARQUET,
+            partition=task.partition,
+            file_size_in_bytes=len(fo),
+            sort_order_id=task.sort_order_id,
+            # Just copy these from the table for now
+            spec_id=table.spec().spec_id,
+            equality_ids=None,
+            key_metadata=None,
+        )
 
-    if len(collected_metrics) != 1:
-        # One file has been written
-        raise ValueError(f"Expected 1 entry, got: {collected_metrics}")
+        if len(collected_metrics) != 1:
+            # One file has been written
+            raise ValueError(f"Expected 1 entry, got: {collected_metrics}")
 
-    fill_parquet_file_metadata(
-        data_file=data_file,
-        parquet_metadata=collected_metrics[0],
-        stats_columns=compute_statistics_plan(table.schema(), table.properties),
-        parquet_column_mapping=parquet_path_to_id_mapping(table.schema()),
-    )
-    return iter([data_file])
+        fill_parquet_file_metadata(
+            data_file=data_file,
+            parquet_metadata=collected_metrics[0],
+            stats_columns=compute_statistics_plan(table.schema(), table.properties),
+            parquet_column_mapping=parquet_path_to_id_mapping(table.schema()),
+        )
+        yield data_file
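
write_file now consumes every WriteTask and lazily yields one DataFile per task instead of rejecting anything beyond a single unpartitioned task. A minimal sketch of how a caller might drive the new generator (not part of this commit; the tbl table object and the arrow_chunks iterable are hypothetical):

import itertools
import uuid

from pyiceberg.io.pyarrow import write_file
from pyiceberg.table import WriteTask


def write_chunks(tbl, arrow_chunks):
    # One WriteTask per PyArrow table; write_file yields a DataFile as each Parquet file is written.
    write_uuid = uuid.uuid4()
    counter = itertools.count(0)
    tasks = iter([WriteTask(write_uuid, next(counter), chunk) for chunk in arrow_chunks])
    return list(write_file(tbl, tasks))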

pyiceberg/table/__init__.py

Lines changed: 144 additions & 21 deletions
@@ -107,6 +107,7 @@
     Identifier,
     KeyDefaultDict,
     Properties,
+    Record,
 )
 from pyiceberg.types import (
     IcebergType,
@@ -940,9 +941,6 @@ def append(self, df: pa.Table) -> None:
         if not isinstance(df, pa.Table):
             raise ValueError(f"Expected PyArrow table, got: {df}")
 
-        if len(self.spec().fields) > 0:
-            raise ValueError("Cannot write to partitioned tables")
-
         if len(self.sort_order().fields) > 0:
             raise ValueError("Cannot write to tables with a sort-order")
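With the partitioned-table guard removed, append on a table whose partition spec uses only identity transforms is expected to go through the new partitioned write path. A hedged usage sketch (the catalog name "default" and the table identifier "default.animals" are made up):

import pyarrow as pa

from pyiceberg.catalog import load_catalog

catalog = load_catalog("default")
tbl = catalog.load_table("default.animals")  # e.g. partitioned by identity(year)

df = pa.table({
    "year": [2021, 2022, 2022],
    "n_legs": [4, 2, 4],
    "animal": ["Dog", "Parrot", "Horse"],
})

# Previously this raised "Cannot write to partitioned tables"; with identity
# transforms it should now write one data file per distinct partition value.
tbl.append(df)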

@@ -2254,14 +2252,28 @@ class WriteTask:
     task_id: int
     df: pa.Table
     sort_order_id: Optional[int] = None
+    partition: Optional[Record] = None
 
-    # Later to be extended with partition information
+    def generate_data_file_partition_path(self) -> str:
+        if self.partition is None:
+            raise ValueError("Cannot generate partition path based on non-partitioned WriteTask")
+        partition_strings = []
+        for field in self.partition._position_to_field_name:
+            value = getattr(self.partition, field)
+            partition_strings.append(f"{field}={value}")
+        return "/".join(partition_strings)
 
     def generate_data_file_filename(self, extension: str) -> str:
         # Mimics the behavior in the Java API:
         # https://github.com/apache/iceberg/blob/a582968975dd30ff4917fbbe999f1be903efac02/core/src/main/java/org/apache/iceberg/io/OutputFileFactory.java#L92-L101
         return f"00000-{self.task_id}-{self.write_uuid}.{extension}"
 
+    def generate_data_file_path(self, extension: str) -> str:
+        if self.partition:
+            return f"{self.generate_data_file_partition_path()}/{self.generate_data_file_filename(extension)}"
+        else:
+            return self.generate_data_file_filename(extension)
+
 
 def _new_manifest_path(location: str, num: int, commit_uuid: uuid.UUID) -> str:
     return f'{location}/metadata/{commit_uuid}-m{num}.avro'
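
generate_data_file_partition_path renders the partition key Hive-style as field=value segments joined with "/", and generate_data_file_path prefixes that to the usual data file name. A small illustration with made-up values (not part of this commit):

import uuid

import pyarrow as pa

from pyiceberg.table import WriteTask
from pyiceberg.typedef import Record

task = WriteTask(
    write_uuid=uuid.UUID("12345678-1234-5678-1234-567812345678"),
    task_id=0,
    df=pa.table({"animal": ["Dog"]}),
    partition=Record(n_legs=4, year=2021),
)

# The partition path comes first, then the task's file name.
print(task.generate_data_file_path("parquet"))
# expected: n_legs=4/year=2021/00000-0-12345678-1234-5678-1234-567812345678.parquet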
@@ -2273,23 +2285,6 @@ def _generate_manifest_list_path(location: str, snapshot_id: int, attempt: int,
     return f'{location}/metadata/snap-{snapshot_id}-{attempt}-{commit_uuid}.avro'
 
 
-def _dataframe_to_data_files(table: Table, df: pa.Table) -> Iterable[DataFile]:
-    from pyiceberg.io.pyarrow import write_file
-
-    if len(table.spec().fields) > 0:
-        raise ValueError("Cannot write to partitioned tables")
-
-    if len(table.sort_order().fields) > 0:
-        raise ValueError("Cannot write to tables with a sort-order")
-
-    write_uuid = uuid.uuid4()
-    counter = itertools.count(0)
-
-    # This is an iter, so we don't have to materialize everything every time
-    # This will be more relevant when we start doing partitioned writes
-    yield from write_file(table, iter([WriteTask(write_uuid, next(counter), df)]))
-
-
 class _MergingSnapshotProducer:
     _operation: Operation
     _table: Table
@@ -2467,3 +2462,131 @@ def commit(self) -> Snapshot:
         )
 
         return snapshot
+
+
+@dataclass(frozen=True)
+class TablePartition:
+    partition_key: Record
+    arrow_table_partition: pa.Table
+
+
+def get_partition_sort_order(partition_columns: list[str], reverse: bool = False) -> dict[str, Any]:
+    order = 'ascending' if not reverse else 'descending'
+    null_placement = 'at_start' if reverse else 'at_end'
+    return {'sort_keys': [(column_name, order) for column_name in partition_columns], 'null_placement': null_placement}
+
+
+def group_by_partition_scheme(iceberg_table: Table, arrow_table: pa.Table, partition_columns: list[str]) -> pa.Table:
+    """Given a table, sort it by the current partition scheme using the supported transform functions."""
+    # todo support hidden partition transform functions
+    from pyiceberg.transforms import IdentityTransform
+
+    supported = {IdentityTransform}
+    if not all(type(field.transform) in supported for field in iceberg_table.spec().fields if field in partition_columns):
+        raise ValueError(
+            f"Not all transforms are supported, get: {[transform in supported for transform in iceberg_table.spec().fields]}."
+        )
+
+    # only works for identity
+    sort_options = get_partition_sort_order(partition_columns, reverse=False)
+    sorted_arrow_table = arrow_table.sort_by(sorting=sort_options['sort_keys'], null_placement=sort_options['null_placement'])
+    return sorted_arrow_table
+
+
+def get_partition_columns(iceberg_table: Table, arrow_table: pa.Table) -> list[str]:
+    arrow_table_cols = set(arrow_table.column_names)
+    partition_cols = []
+    for transform_field in iceberg_table.spec().fields:
+        column_name = iceberg_table.schema().find_column_name(transform_field.source_id)
+        if not column_name:
+            raise ValueError(f"{transform_field=} could not be found in {iceberg_table.schema()}.")
+        if column_name not in arrow_table_cols:
+            continue
+        partition_cols.append(column_name)
+    return partition_cols
+
+
+def get_partition_key(arrow_table: pa.Table, partition_columns: list[str], offset: int) -> Record:
+    # todo: Instead of fetching partition keys one at a time, try filtering by a mask built from the offsets
+    # and converting to Python together, which is possibly slightly more efficient.
+    return Record(**{col: arrow_table.column(col)[offset].as_py() for col in partition_columns})
+
+
+def partition(iceberg_table: Table, arrow_table: pa.Table) -> Iterable[TablePartition]:
+    """Based on the Iceberg table's partition spec, slice the Arrow table into partitions with their keys.
+
+    Example:
+        Input:
+        An arrow table with partition columns ['n_legs', 'year'] and data
+        {'year': [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021],
+         'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100],
+         'animal': ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", "Horse", "Brittle stars", "Centipede"]}.
+
+        The algorithm:
+        First we group the rows into partitions by sorting with sort order [('n_legs', 'ascending'), ('year', 'ascending')]
+        and null_placement of "at_end".
+        For this input that gives the same table as the raw input.
+
+        Then we sort_indices using the reversed order [('n_legs', 'descending'), ('year', 'descending')]
+        and null_placement of "at_start".
+        This gives:
+        [8, 7, 4, 5, 6, 3, 1, 2, 0]
+
+        Based on this we get partition groups of indices:
+        [{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 'length': 3}, {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, {'offset': 0, 'length': 1}]
+
+        We then retrieve the partition keys by offset,
+        and slice the arrow table by the offset and length of each partition.
+    """
+    import pyarrow as pa
+
+    partition_columns = get_partition_columns(iceberg_table, arrow_table)
+
+    arrow_table = group_by_partition_scheme(iceberg_table, arrow_table, partition_columns)
+
+    reversing_sort_order_options = get_partition_sort_order(partition_columns, reverse=True)
+    reversed_indices = pa.compute.sort_indices(arrow_table, **reversing_sort_order_options).to_pylist()
+
+    slice_instructions = []
+    last = len(reversed_indices)
+    reversed_indices_size = len(reversed_indices)
+    ptr = 0
+    while ptr < reversed_indices_size:
+        group_size = last - reversed_indices[ptr]
+        offset = reversed_indices[ptr]
+        slice_instructions.append({"offset": offset, "length": group_size})
+        last = reversed_indices[ptr]
+        ptr = ptr + group_size
+
+    table_partitions: list[TablePartition] = [
+        TablePartition(
+            partition_key=get_partition_key(arrow_table, partition_columns, inst["offset"]),
+            arrow_table_partition=arrow_table.slice(**inst),
+        )
+        for inst in slice_instructions
+    ]
+    return table_partitions
+
+
+def _dataframe_to_data_files(table: Table, df: pa.Table) -> Iterable[DataFile]:
+    from pyiceberg.io.pyarrow import write_file  # , stubbed
+
+    if len(table.sort_order().fields) > 0:
+        raise ValueError("Cannot write to tables with a sort-order")
+
+    write_uuid = uuid.uuid4()
+    counter = itertools.count(0)
+
+    if len(table.spec().fields) > 0:
+        partitions = partition(table, df)
+        yield from write_file(
+            table,
+            iter([
+                WriteTask(write_uuid, next(counter), partition.arrow_table_partition, partition=partition.partition_key)
+                for partition in partitions
+            ]),
+        )
+    else:
+        # This is an iter, so we don't have to materialize everything every time
+        # This will be more relevant when we start doing partitioned writes
+        yield from write_file(table, iter([WriteTask(write_uuid, next(counter), df)]))
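
The offset/length bookkeeping in partition() can be reproduced with plain PyArrow on the docstring's worked example. A standalone sketch (not part of this commit): sort the table by the partition columns, sort the row indices in the opposite direction, then walk the runs to get one slice per partition key.

import pyarrow as pa
import pyarrow.compute as pc

table = pa.table({
    "year": [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021],
    "n_legs": [2, 2, 2, 4, 4, 4, 4, 5, 100],
    "animal": ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", "Horse", "Brittle stars", "Centipede"],
})
keys = ["n_legs", "year"]

# 1. Group rows into contiguous runs per partition key (identity transforms only).
table = table.sort_by([(k, "ascending") for k in keys], null_placement="at_end")

# 2. Sort the row indices in the opposite direction; the index at the start of each
#    run is the offset of that partition group in the sorted table.
reversed_indices = pc.sort_indices(
    table, sort_keys=[(k, "descending") for k in keys], null_placement="at_start"
).to_pylist()

# 3. Walk the runs: each group's length is the gap back to where the previous group began.
slice_instructions = []
last = len(reversed_indices)
ptr = 0
while ptr < len(reversed_indices):
    offset = reversed_indices[ptr]
    length = last - offset
    slice_instructions.append({"offset": offset, "length": length})
    last = offset
    ptr += length

print(slice_instructions)
# expected (per the docstring example): [{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1},
#   {'offset': 4, 'length': 3}, {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, {'offset': 0, 'length': 1}]
# Each instruction can then be applied with table.slice(**inst) and paired with the
# partition key read at inst["offset"].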
