
Commit 805c19c

PartitionKey; hive partition path; transform key

1 parent 95029bb

File tree

4 files changed: +196 -64 lines

  pyiceberg/partitioning.py
  pyiceberg/table/__init__.py
  tests/integration/test_partitioned_writes.py
  tests/table/test_init.py

pyiceberg/partitioning.py

Lines changed: 56 additions & 4 deletions
@@ -16,7 +16,9 @@
 # under the License.
 from __future__ import annotations

-from functools import cached_property
+from dataclasses import dataclass
+from datetime import date, datetime
+from functools import cached_property, singledispatch
 from typing import (
     Any,
     Dict,
@@ -36,7 +38,8 @@
 from pyiceberg.schema import Schema
 from pyiceberg.transforms import Transform, parse_transform
 from pyiceberg.typedef import IcebergBaseModel, Record
-from pyiceberg.types import NestedField, StructType
+from pyiceberg.types import DateType, IcebergType, NestedField, StructType, TimestampType, TimestamptzType
+from pyiceberg.utils.datetime import date_to_days, datetime_to_micros

 INITIAL_PARTITION_SPEC_ID = 0
 PARTITION_FIELD_ID_START: int = 1000
@@ -97,7 +100,6 @@ class PartitionSpec(IcebergBaseModel):

     spec_id: int = Field(alias="spec-id", default=INITIAL_PARTITION_SPEC_ID)
     fields: Tuple[PartitionField, ...] = Field(default_factory=tuple)
-    schema: Schema

     def __init__(
         self,
@@ -205,7 +207,7 @@ def partition_to_path(self, data: Record, schema: Schema) -> str:
             value = getattr(data, field_name)

             partition_field = self.fields[pos]  # partition field
-            value_str = partition_field.transform.to_human_string(source_type=field_types[pos].field_type, value=value)
+            value_str = partition_field.transform.to_human_string(field_types[pos].field_type, value=value)
             value_strs.append(value_str)
             field_strs.append(partition_field.name)
             pos += 1
@@ -234,3 +236,53 @@ def assign_fresh_partition_spec_ids(spec: PartitionSpec, old_schema: Schema, fre
         )
     )
     return PartitionSpec(*partition_fields, spec_id=INITIAL_PARTITION_SPEC_ID)
+
+
+@dataclass(frozen=True)
+class PartitionFieldValue:
+    # It seems partition fields cannot be nested or have a map/list structure,
+    # so instead of using an accessor built through a schema visitor (as iceberg-spark does) to fetch the partition value,
+    # this simple class is used for the first iteration.
+    # Open to discussion and willing to change to conform to row accessors.
+    source_id: int
+    value: Any
+
+
+@dataclass(frozen=True)
+class PartitionKey:
+    raw_partition_field_values: list[PartitionFieldValue]
+    partition_spec: PartitionSpec
+    schema: Schema
+
+    @cached_property
+    def partition(self) -> Record:  # partition key in iceberg type
+        iceberg_typed_key_values = {}
+        for raw_partition_field_value in self.raw_partition_field_values:
+            partition_fields = self.partition_spec.source_id_to_fields_map[raw_partition_field_value.source_id]
+            assert len(partition_fields) == 1
+            partition_field = partition_fields[0]
+            iceberg_type = self.schema.find_field(name_or_id=raw_partition_field_value.source_id).field_type
+            _iceberg_typed_value = iceberg_typed_value(iceberg_type, raw_partition_field_value.value)
+            transformed_value = partition_field.transform.transform(iceberg_type)(_iceberg_typed_value)
+            iceberg_typed_key_values[partition_field.name] = transformed_value
+        return Record(**iceberg_typed_key_values)
+
+    def to_path(self) -> str:
+        return self.partition_spec.partition_to_path(self.partition, self.schema)
+
+
+@singledispatch
+def iceberg_typed_value(type: IcebergType, value: Any) -> Any:
+    return value
+
+
+@iceberg_typed_value.register(TimestampType)
+@iceberg_typed_value.register(TimestamptzType)
+def _(type: IcebergType, value: Optional[datetime]) -> Optional[int]:
+    return datetime_to_micros(value) if value is not None else None
+
+
+@iceberg_typed_value.register(DateType)
+def _(type: IcebergType, value: Optional[date]) -> Optional[int]:
+    return date_to_days(value) if value is not None else None
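
The singledispatch helpers above normalize raw Python values into Iceberg's internal representations (dates to days since the Unix epoch, timestamps to microseconds since the Unix epoch) before the partition transform runs. A minimal sketch, separate from the diff, of the values these converters and the subsequent transform call are expected to produce (the example values are illustrative):

from datetime import date, datetime, timedelta

from pyiceberg.transforms import MonthTransform
from pyiceberg.types import TimestampType

# Plain-Python equivalents of what the registered converters above are expected to return.
days_since_epoch = (date(2023, 1, 1) - date(1970, 1, 1)).days  # 19358
micros_since_epoch = (datetime(2023, 1, 1, 11, 55, 59) - datetime(1970, 1, 1)) // timedelta(microseconds=1)

# PartitionKey.partition then applies the partition transform to the converted value;
# a month transform over the timestamp should yield months since epoch.
months = MonthTransform().transform(TimestampType())(micros_since_epoch)  # expected: 636 for January 2023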

pyiceberg/table/__init__.py

Lines changed: 38 additions & 57 deletions
@@ -22,7 +22,7 @@
 from abc import ABC, abstractmethod
 from copy import copy
 from dataclasses import dataclass
-from datetime import date, datetime
+from datetime import datetime
 from enum import Enum
 from functools import cached_property, singledispatch
 from itertools import chain
@@ -67,7 +67,7 @@
     write_manifest,
     write_manifest_list,
 )
-from pyiceberg.partitioning import PartitionSpec
+from pyiceberg.partitioning import PartitionFieldValue, PartitionKey, PartitionSpec
 from pyiceberg.schema import (
     PartnerAccessor,
     Schema,
@@ -107,7 +107,6 @@
     Identifier,
     KeyDefaultDict,
     Properties,
-    Record,
 )
 from pyiceberg.types import (
     IcebergType,
@@ -118,7 +117,7 @@
     StructType,
 )
 from pyiceberg.utils.concurrent import ExecutorFactory
-from pyiceberg.utils.datetime import date_to_days, datetime_to_micros, datetime_to_millis
+from pyiceberg.utils.datetime import datetime_to_millis

 if TYPE_CHECKING:
     import pandas as pd
@@ -2257,7 +2256,7 @@ class WriteTask:
     def generate_data_file_partition_path(self) -> str:
         if self.partition_key is None:
             raise ValueError("Cannot generate partition path based on non-partitioned WriteTask")
-        return self.partition_key.to_path(self.schema)
+        return self.partition_key.to_path()

     def generate_data_file_filename(self, extension: str) -> str:
         # Mimics the behavior in the Java API:
@@ -2467,41 +2466,6 @@ class TablePartition:
     arrow_table_partition: pa.Table


-@dataclass(frozen=True)
-class PartitionKey:
-    raw_partition_key: Record  # partition key in raw python type
-    partition_spec: PartitionSpec
-
-    # this only supports identity transform now
-    @property
-    def partition(self) -> Record:  # partition key in iceberg type
-        iceberg_typed_key_values = {
-            field_name: iceberg_typed_value(getattr(self.raw_partition_key, field_name, None))
-            for field_name in self.raw_partition_key._position_to_field_name
-        }
-
-        return Record(**iceberg_typed_key_values)
-
-    def to_path(self, schema: Schema) -> str:
-        return self.partition_spec.partition_to_path(self.partition, schema)
-
-
-@singledispatch
-def iceberg_typed_value(value: Any) -> Any:
-    return value
-
-
-@iceberg_typed_value.register(datetime)
-def _(value: Any) -> int:
-    val = datetime_to_micros(value)
-    return val
-
-
-@iceberg_typed_value.register(date)
-def _(value: Any) -> int:
-    return date_to_days(value)
-
-
 def _get_partition_sort_order(partition_columns: list[str], reverse: bool = False) -> dict[str, Any]:
     order = 'ascending' if not reverse else 'descending'
     null_placement = 'at_start' if reverse else 'at_end'
@@ -2538,15 +2502,35 @@ def _get_partition_columns(iceberg_table: Table, arrow_table: pa.Table) -> list[
     return partition_cols


-def _get_partition_key(
-    arrow_table: pa.Table, partition_columns: list[str], offset: int, partition_spec: PartitionSpec
-) -> PartitionKey:
-    # todo: Instead of fetching partition keys one at a time, try filtering by a mask made of offsets, and convert to py together,
-    # possibly slightly more efficient.
-    return PartitionKey(
-        raw_partition_key=Record(**{col: arrow_table.column(col)[offset].as_py() for col in partition_columns}),
-        partition_spec=partition_spec,
-    )
+def _get_table_partitions(
+    arrow_table: pa.Table,
+    partition_spec: PartitionSpec,
+    schema: Schema,
+    slice_instructions: list[dict[str, Any]],
+) -> list[TablePartition]:
+    sorted_slice_instructions = sorted(slice_instructions, key=lambda x: x['offset'])
+
+    partition_fields = partition_spec.fields
+
+    offsets = [inst["offset"] for inst in sorted_slice_instructions]
+    projected_and_filtered = {
+        partition_field.source_id: arrow_table[schema.find_field(name_or_id=partition_field.source_id).name]
+        .take(offsets)
+        .to_pylist()
+        for partition_field in partition_fields
+    }
+
+    table_partitions = []
+    for inst in sorted_slice_instructions:
+        partition_slice = arrow_table.slice(**inst)
+        fieldvalues = [
+            PartitionFieldValue(partition_field.source_id, projected_and_filtered[partition_field.source_id][inst["offset"]])
+            for partition_field in partition_fields
+        ]
+        partition_key = PartitionKey(raw_partition_field_values=fieldvalues, partition_spec=partition_spec, schema=schema)
+        table_partitions.append(TablePartition(partition_key=partition_key, arrow_table_partition=partition_slice))
+
+    return table_partitions


 def _partition(iceberg_table: Table, arrow_table: pa.Table) -> Iterable[TablePartition]:
@@ -2584,7 +2568,7 @@ def _partition(iceberg_table: Table, arrow_table: pa.Table) -> Iterable[TablePar
     reversing_sort_order_options = _get_partition_sort_order(partition_columns, reverse=True)
     reversed_indices = pa.compute.sort_indices(arrow_table, **reversing_sort_order_options).to_pylist()

-    slice_instructions = []
+    slice_instructions: list[dict[str, Any]] = []
     last = len(reversed_indices)
     reversed_indices_size = len(reversed_indices)
     ptr = 0
@@ -2595,13 +2579,10 @@ def _partition(iceberg_table: Table, arrow_table: pa.Table) -> Iterable[TablePar
         last = reversed_indices[ptr]
         ptr = ptr + group_size

-    table_partitions: list[TablePartition] = [
-        TablePartition(
-            partition_key=_get_partition_key(arrow_table, partition_columns, inst["offset"], iceberg_table.spec()),
-            arrow_table_partition=arrow_table.slice(**inst),
-        )
-        for inst in slice_instructions
-    ]
+    table_partitions: list[TablePartition] = _get_table_partitions(
+        arrow_table, iceberg_table.spec(), iceberg_table.schema(), slice_instructions
+    )
+
     return table_partitions

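The new _get_table_partitions consumes slice_instructions, a list of {"offset", "length"} dicts computed over a table that has already been sorted by its partition columns: each dict carves one partition's contiguous rows out of the sorted table, and the partition source columns are gathered once with take(offsets). A small standalone PyArrow sketch of that slicing idea (the table contents and instructions are illustrative, not taken from the diff):

import pyarrow as pa

# A table assumed to be pre-sorted by its single partition source column.
table = pa.table({"ts_month": [636, 636, 637], "value": [10, 20, 30]})

slice_instructions = [
    {"offset": 0, "length": 2},  # rows for month 636
    {"offset": 2, "length": 1},  # rows for month 637
]

# One contiguous slice per partition, plus the partition value at each slice's first row.
partition_slices = [table.slice(**inst) for inst in slice_instructions]
partition_values = table["ts_month"].take([inst["offset"] for inst in slice_instructions]).to_pylist()

print([s.num_rows for s in partition_slices])  # [2, 1]
print(partition_values)                        # [636, 637]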

tests/integration/test_partitioned_writes.py

Lines changed: 0 additions & 3 deletions
@@ -360,9 +360,6 @@ def test_query_filter_null_partitioned(spark: SparkSession, part_col: str, forma
     for col in TEST_DATA_WITH_NULL.keys():
         assert df.where(f"{col} is not null").count() == 2, f"Expected 2 rows for {col}"

-    spark.sql(f"select path from {identifier}.manifests").show(20, False)
-    spark.sql(f"select path from {identifier}.manifests").collect()
-

 @pytest.mark.integration
 @pytest.mark.parametrize(

tests/table/test_init.py

Lines changed: 102 additions & 0 deletions
@@ -982,3 +982,105 @@ def test_correct_schema() -> None:
         _ = t.scan(snapshot_id=-1).projection()

     assert "Snapshot not found: -1" in str(exc_info.value)
+
+import pytz
+from datetime import date, datetime
+TEST_DATA_WITH_NULL = {
+    'bool': [False, None, True],
+    'string': ['a', None, 'z'],
+    # Go over the 16 bytes to kick in truncation
+    'string_long': ['a' * 22, None, 'z' * 22],
+    'int': [1, None, 9],
+    'long': [1, None, 9],
+    'float': [0.0, None, 0.9],
+    'double': [0.0, None, 0.9],
+    'timestamp': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
+    'timestamptz': [
+        datetime(2023, 1, 1, 19, 25, 00, tzinfo=pytz.timezone('America/New_York')),
+        None,
+        datetime(2023, 3, 1, 19, 25, 00, tzinfo=pytz.timezone('America/New_York')),
+    ],
+    'date': [date(2023, 1, 1), None, date(2023, 3, 1)],
+    # Not supported by Spark
+    # 'time': [time(1, 22, 0), None, time(19, 25, 0)],
+    # Not natively supported by Arrow
+    # 'uuid': [uuid.UUID('00000000-0000-0000-0000-000000000000').bytes, None, uuid.UUID('11111111-1111-1111-1111-111111111111').bytes],
+    'binary': [b'\01', None, b'\22'],
+    'fixed': [
+        uuid.UUID('00000000-0000-0000-0000-000000000000').bytes,
+        None,
+        uuid.UUID('11111111-1111-1111-1111-111111111111').bytes,
+    ],
+}
+import pyarrow as pa
+@pytest.fixture(scope="session")
+def arrow_table_with_null() -> pa.Table:
+    """PyArrow table with all kinds of columns"""
+    pa_schema = pa.schema([
+        ("bool", pa.bool_()),
+        ("string", pa.string()),
+        ("string_long", pa.string()),
+        ("int", pa.int32()),
+        ("long", pa.int64()),
+        ("float", pa.float32()),
+        ("double", pa.float64()),
+        ("timestamp", pa.timestamp(unit="us")),
+        ("timestamptz", pa.timestamp(unit="us", tz="UTC")),
+        ("date", pa.date32()),
+        # Not supported by Spark
+        # ("time", pa.time64("us")),
+        # Not natively supported by Arrow
+        # ("uuid", pa.fixed(16)),
+        ("binary", pa.binary()),
+        ("fixed", pa.binary(16)),
+    ])
+    return pa.Table.from_pydict(TEST_DATA_WITH_NULL, schema=pa_schema)

+from pyiceberg.schema import Schema
+from pyiceberg.transforms import IdentityTransform, DayTransform, MonthTransform
+from pyiceberg.types import (
+    BinaryType,
+    BooleanType,
+    DateType,
+    DoubleType,
+    FixedType,
+    FloatType,
+    IntegerType,
+    LongType,
+    NestedField,
+    StringType,
+    TimestampType,
+    TimestamptzType,
+)
+TABLE_SCHEMA = Schema(
+    NestedField(field_id=1, name="bool", field_type=BooleanType(), required=False),
+    NestedField(field_id=2, name="string", field_type=StringType(), required=False),
+    NestedField(field_id=3, name="string_long", field_type=StringType(), required=False),
+    NestedField(field_id=4, name="int", field_type=IntegerType(), required=False),
+    NestedField(field_id=5, name="long", field_type=LongType(), required=False),
+    NestedField(field_id=6, name="float", field_type=FloatType(), required=False),
+    NestedField(field_id=7, name="double", field_type=DoubleType(), required=False),
+    NestedField(field_id=8, name="timestamp", field_type=TimestampType(), required=False),
+    NestedField(field_id=9, name="timestamptz", field_type=TimestamptzType(), required=False),
+    NestedField(field_id=10, name="date", field_type=DateType(), required=False),
+    # NestedField(field_id=11, name="time", field_type=TimeType(), required=False),
+    # NestedField(field_id=12, name="uuid", field_type=UuidType(), required=False),
+    NestedField(field_id=11, name="binary", field_type=BinaryType(), required=False),
+    NestedField(field_id=12, name="fixed", field_type=FixedType(16), required=False),
+)

+@pytest.mark.mexico
+def test_partition_key(arrow_table_with_null) -> None:
+    from pyiceberg.table import PartitionKeyNew, PartitionFieldValue
+    from pyiceberg.partitioning import PartitionField, PartitionSpec
+
+    spec = PartitionSpec(
+        PartitionField(source_id=8, field_id=1001, transform=MonthTransform(), name="test_partition_field")
+    )
+
+
+    key = PartitionKeyNew(
+        raw_partition_key_values=[PartitionFieldValue(source_id=8, value=datetime(2023, 1, 1, 11, 55, 59))],
+        partition_spec=spec,
+        schema=TABLE_SCHEMA,
+    )
+    print(key.partition)
+    print("-----------")
+    print(key.to_path())
+    print("1")
