     Identifier,
     KeyDefaultDict,
     Properties,
+    Record,
 )
 from pyiceberg.types import (
     IcebergType,
@@ -940,9 +941,6 @@ def append(self, df: pa.Table) -> None:
         if not isinstance(df, pa.Table):
             raise ValueError(f"Expected PyArrow table, got: {df}")
 
-        if len(self.spec().fields) > 0:
-            raise ValueError("Cannot write to partitioned tables")
-
         if len(self.sort_order().fields) > 0:
             raise ValueError("Cannot write to tables with a sort-order")
 
@@ -2254,14 +2252,28 @@ class WriteTask:
     task_id: int
     df: pa.Table
     sort_order_id: Optional[int] = None
+    partition: Optional[Record] = None
 
-    # Later to be extended with partition information
+    def generate_data_file_partition_path(self) -> str:
+        if self.partition is None:
+            raise ValueError("Cannot generate partition path based on non-partitioned WriteTask")
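+        # Renders the key as Hive-style "field=value" segments joined by "/",
+        # e.g. a Record(n_legs=4, year=2021) key becomes "n_legs=4/year=2021".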
+        partition_strings = []
+        for field in self.partition._position_to_field_name:
+            value = getattr(self.partition, field)
+            partition_strings.append(f"{field}={value}")
+        return "/".join(partition_strings)
 
     def generate_data_file_filename(self, extension: str) -> str:
         # Mimics the behavior in the Java API:
         # https://github.com/apache/iceberg/blob/a582968975dd30ff4917fbbe999f1be903efac02/core/src/main/java/org/apache/iceberg/io/OutputFileFactory.java#L92-L101
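+        # e.g. task_id 3 and extension "parquet" -> "00000-3-<write_uuid>.parquet"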
         return f"00000-{self.task_id}-{self.write_uuid}.{extension}"
 
+    def generate_data_file_path(self, extension: str) -> str:
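+        # e.g. "n_legs=4/year=2021/00000-3-<write_uuid>.parquet" for a partitioned task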
+        if self.partition:
+            return f"{self.generate_data_file_partition_path()}/{self.generate_data_file_filename(extension)}"
+        else:
+            return self.generate_data_file_filename(extension)
+
 
 def _new_manifest_path(location: str, num: int, commit_uuid: uuid.UUID) -> str:
     return f'{location}/metadata/{commit_uuid}-m{num}.avro'
@@ -2273,23 +2285,6 @@ def _generate_manifest_list_path(location: str, snapshot_id: int, attempt: int,
     return f'{location}/metadata/snap-{snapshot_id}-{attempt}-{commit_uuid}.avro'
 
 
-def _dataframe_to_data_files(table: Table, df: pa.Table) -> Iterable[DataFile]:
-    from pyiceberg.io.pyarrow import write_file
-
-    if len(table.spec().fields) > 0:
-        raise ValueError("Cannot write to partitioned tables")
-
-    if len(table.sort_order().fields) > 0:
-        raise ValueError("Cannot write to tables with a sort-order")
-
-    write_uuid = uuid.uuid4()
-    counter = itertools.count(0)
-
-    # This is an iter, so we don't have to materialize everything every time
-    # This will be more relevant when we start doing partitioned writes
-    yield from write_file(table, iter([WriteTask(write_uuid, next(counter), df)]))
-
-
 class _MergingSnapshotProducer:
     _operation: Operation
     _table: Table
@@ -2467,3 +2462,131 @@ def commit(self) -> Snapshot:
         )
 
         return snapshot
+
+
+@dataclass(frozen=True)
+class TablePartition:
+    partition_key: Record
+    arrow_table_partition: pa.Table
+
+
+def get_partition_sort_order(partition_columns: list[str], reverse: bool = False) -> dict[str, Any]:
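+    # e.g. get_partition_sort_order(["n_legs", "year"]) ->
+    # {'sort_keys': [('n_legs', 'ascending'), ('year', 'ascending')], 'null_placement': 'at_end'}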
+    order = 'ascending' if not reverse else 'descending'
+    null_placement = 'at_start' if reverse else 'at_end'
+    return {'sort_keys': [(column_name, order) for column_name in partition_columns], 'null_placement': null_placement}
+
+
+def group_by_partition_scheme(iceberg_table: Table, arrow_table: pa.Table, partition_columns: list[str]) -> pa.Table:
+    """Given a table, sort it by the current partition scheme."""
+    # TODO: support hidden partitioning via the other transform functions
+    from pyiceberg.transforms import IdentityTransform
+
+    supported = {IdentityTransform}
+    if not all(
+        type(field.transform) in supported
+        for field in iceberg_table.spec().fields
+        if iceberg_table.schema().find_column_name(field.source_id) in partition_columns
+    ):
+        raise ValueError(
+            f"Not all transforms are supported, got: {[field.transform for field in iceberg_table.spec().fields]}."
+        )
+
+    # Sorting by the source columns only works for identity transforms,
+    # where the partition value equals the source column value.
+    sort_options = get_partition_sort_order(partition_columns, reverse=False)
+    sorted_arrow_table = arrow_table.sort_by(sorting=sort_options['sort_keys'], null_placement=sort_options['null_placement'])
+    return sorted_arrow_table
+
+
+def get_partition_columns(iceberg_table: Table, arrow_table: pa.Table) -> list[str]:
+    arrow_table_cols = set(arrow_table.column_names)
+    partition_cols = []
+    for transform_field in iceberg_table.spec().fields:
+        column_name = iceberg_table.schema().find_column_name(transform_field.source_id)
+        if not column_name:
+            raise ValueError(f"{transform_field=} could not be found in {iceberg_table.schema()}.")
+        if column_name not in arrow_table_cols:
+            continue
+        partition_cols.append(column_name)
+    return partition_cols
+
+
+def get_partition_key(arrow_table: pa.Table, partition_columns: list[str], offset: int) -> Record:
+    # TODO: Instead of fetching partition keys one at a time, try filtering by a mask made of
+    # offsets and converting to Python values together, which is possibly slightly more efficient.
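+    # e.g. for the docstring example below, offset 3 -> Record(n_legs=4, year=2021)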
+    return Record(**{col: arrow_table.column(col)[offset].as_py() for col in partition_columns})
+
+
+def partition(iceberg_table: Table, arrow_table: pa.Table) -> Iterable[TablePartition]:
+    """Based on the iceberg table partition spec, slice the arrow table into partitions with their keys.
+
+    Example:
+        Input:
+        An arrow table with partition keys ['n_legs', 'year'] and data
+        {'year': [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021],
+         'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100],
+         'animal': ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", "Horse", "Brittle stars", "Centipede"]}.
+
+        The algorithm:
+        First we group the rows into partitions by sorting with the sort order
+        [('n_legs', 'ascending'), ('year', 'ascending')] and null_placement of "at_end".
+        This gives the same table as the raw input.
+
+        Then we compute sort indices with the reversed sort order
+        [('n_legs', 'descending'), ('year', 'descending')] and null_placement of "at_start".
+        This gives:
+        [8, 7, 4, 5, 6, 3, 1, 2, 0]
+
+        Based on this we get the partition groups of indices:
+        [{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 'length': 3},
+         {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, {'offset': 0, 'length': 1}]
+
+        We then retrieve the partition keys by offset,
+        and slice the arrow table by the offset and length of each partition.
+    """
+    import pyarrow as pa
+
+    partition_columns = get_partition_columns(iceberg_table, arrow_table)
+
+    arrow_table = group_by_partition_scheme(iceberg_table, arrow_table, partition_columns)
+
+    reversing_sort_order_options = get_partition_sort_order(partition_columns, reverse=True)
+    reversed_indices = pa.compute.sort_indices(arrow_table, **reversing_sort_order_options).to_pylist()
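+    # In the reversed ordering, the entry at each group boundary is the offset of that group's
+    # first row in the sorted table; group sizes are the differences between consecutive offsets.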
+
+    slice_instructions = []
+    last = len(reversed_indices)
+    reversed_indices_size = len(reversed_indices)
+    ptr = 0
+    while ptr < reversed_indices_size:
+        group_size = last - reversed_indices[ptr]
+        offset = reversed_indices[ptr]
+        slice_instructions.append({"offset": offset, "length": group_size})
+        last = reversed_indices[ptr]
+        ptr = ptr + group_size
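+    # For the docstring example, reversed_indices [8, 7, 4, 5, 6, 3, 1, 2, 0] yields
+    # slice_instructions [{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 'length': 3},
+    # {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, {'offset': 0, 'length': 1}].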
+
+    table_partitions: list[TablePartition] = [
+        TablePartition(
+            partition_key=get_partition_key(arrow_table, partition_columns, inst["offset"]),
+            arrow_table_partition=arrow_table.slice(**inst),
+        )
+        for inst in slice_instructions
+    ]
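+    # For the docstring example this yields six TablePartitions; the first holds
+    # partition_key Record(n_legs=100, year=2021) and a one-row arrow_table_partition.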
+    return table_partitions
+
+
+def _dataframe_to_data_files(table: Table, df: pa.Table) -> Iterable[DataFile]:
+    from pyiceberg.io.pyarrow import write_file
+
+    if len(table.sort_order().fields) > 0:
+        raise ValueError("Cannot write to tables with a sort-order")
+
+    write_uuid = uuid.uuid4()
+    counter = itertools.count(0)
+
+    if len(table.spec().fields) > 0:
+        table_partitions = partition(table, df)
+        yield from write_file(
+            table,
+            iter([
+                WriteTask(write_uuid, next(counter), table_partition.arrow_table_partition, partition=table_partition.partition_key)
+                for table_partition in table_partitions
+            ]),
+        )
+    else:
+        # This is an iter, so we don't have to materialize everything every time
+        yield from write_file(table, iter([WriteTask(write_uuid, next(counter), df)]))
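+# A minimal usage sketch (hypothetical table and catalog names; assumes the table's spec
+# uses only identity transforms, which is all group_by_partition_scheme supports):
+#     tbl = catalog.load_table("default.animals")
+#     tbl.append(arrow_table)  # now also slices and writes to partitioned tables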