Skip to content

Commit d6e9e73

Browse files
committed
add uuid partition type
1 parent f9be89f commit d6e9e73

File tree

2 files changed

+47
-58
lines changed

2 files changed

+47
-58
lines changed

pyiceberg/partitioning.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19+
import uuid
1920
from abc import ABC, abstractmethod
2021
from dataclasses import dataclass
2122
from datetime import date, datetime
@@ -53,7 +54,16 @@
5354
parse_transform,
5455
)
5556
from pyiceberg.typedef import IcebergBaseModel, Record
56-
from pyiceberg.types import DateType, IcebergType, NestedField, PrimitiveType, StructType, TimestampType, TimestamptzType
57+
from pyiceberg.types import (
58+
DateType,
59+
IcebergType,
60+
NestedField,
61+
PrimitiveType,
62+
StructType,
63+
TimestampType,
64+
TimestamptzType,
65+
UUIDType,
66+
)
5767
from pyiceberg.utils.datetime import date_to_days, datetime_to_micros
5868

5969
INITIAL_PARTITION_SPEC_ID = 0
@@ -370,15 +380,15 @@ class PartitionKey:
370380
schema: Schema
371381

372382
@cached_property
373-
def partition(self) -> Record: # partition key in iceberg type
383+
def partition(self) -> Record: # partition key transformed with iceberg internal representation as input
374384
iceberg_typed_key_values = {}
375385
for raw_partition_field_value in self.raw_partition_field_values:
376386
partition_fields = self.partition_spec.source_id_to_fields_map[raw_partition_field_value.field.source_id]
377387
if len(partition_fields) != 1:
378388
raise ValueError("partition_fields must contain exactly one field.")
379389
partition_field = partition_fields[0]
380390
iceberg_type = self.schema.find_field(name_or_id=raw_partition_field_value.field.source_id).field_type
381-
iceberg_typed_value = _to_iceberg_type(iceberg_type, raw_partition_field_value.value)
391+
iceberg_typed_value = _to_iceberg_internal_representation(iceberg_type, raw_partition_field_value.value)
382392
transformed_value = partition_field.transform.transform(iceberg_type)(iceberg_typed_value)
383393
iceberg_typed_key_values[partition_field.name] = transformed_value
384394
return Record(**iceberg_typed_key_values)
@@ -388,21 +398,26 @@ def to_path(self) -> str:
388398

389399

390400
@singledispatch
391-
def _to_iceberg_type(type: IcebergType, value: Any) -> Any:
401+
def _to_iceberg_internal_representation(type: IcebergType, value: Any) -> Any:
392402
return TypeError(f"Unsupported partition field type: {type}")
393403

394404

395-
@_to_iceberg_type.register(TimestampType)
396-
@_to_iceberg_type.register(TimestamptzType)
405+
@_to_iceberg_internal_representation.register(TimestampType)
406+
@_to_iceberg_internal_representation.register(TimestamptzType)
397407
def _(type: IcebergType, value: Optional[datetime]) -> Optional[int]:
398408
return datetime_to_micros(value) if value is not None else None
399409

400410

401-
@_to_iceberg_type.register(DateType)
411+
@_to_iceberg_internal_representation.register(DateType)
402412
def _(type: IcebergType, value: Optional[date]) -> Optional[int]:
403413
return date_to_days(value) if value is not None else None
404414

405415

406-
@_to_iceberg_type.register(PrimitiveType)
416+
@_to_iceberg_internal_representation.register(UUIDType)
417+
def _(type: IcebergType, value: Optional[uuid.UUID]) -> Optional[str]:
418+
return str(value) if value is not None else None
419+
420+
421+
@_to_iceberg_internal_representation.register(PrimitiveType)
407422
def _(type: IcebergType, value: Optional[Any]) -> Optional[Any]:
408423
return value

tests/integration/test_partitioning_key.py

Lines changed: 24 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
# pylint:disable=redefined-outer-name
18+
import uuid
1819
from datetime import date, datetime, timedelta, timezone
1920
from decimal import Decimal
2021
from typing import Any, List
@@ -50,56 +51,9 @@
5051
StringType,
5152
TimestampType,
5253
TimestamptzType,
54+
UUIDType,
5355
)
5456

55-
# @pytest.fixture(scope="session")
56-
# def session_catalog() -> Catalog:
57-
# return load_catalog(
58-
# "local",
59-
# **{
60-
# "type": "rest",
61-
# "uri": "http://localhost:8181",
62-
# "s3.endpoint": "http://localhost:9000",
63-
# "s3.access-key-id": "admin",
64-
# "s3.secret-access-key": "password",
65-
# },
66-
# )
67-
68-
69-
# @pytest.fixture(scope="session")
70-
# def spark() -> SparkSession:
71-
# import importlib.metadata
72-
# import os
73-
74-
# spark_version = ".".join(importlib.metadata.version("pyspark").split(".")[:2])
75-
# scala_version = "2.12"
76-
# iceberg_version = "1.4.3"
77-
78-
# os.environ["PYSPARK_SUBMIT_ARGS"] = (
79-
# f"--packages org.apache.iceberg:iceberg-spark-runtime-{spark_version}_{scala_version}:{iceberg_version},"
80-
# f"org.apache.iceberg:iceberg-aws-bundle:{iceberg_version} pyspark-shell"
81-
# )
82-
# os.environ["AWS_REGION"] = "us-east-1"
83-
# os.environ["AWS_ACCESS_KEY_ID"] = "admin"
84-
# os.environ["AWS_SECRET_ACCESS_KEY"] = "password"
85-
86-
# spark = (
87-
# SparkSession.builder.appName("PyIceberg integration test")
88-
# .config("spark.sql.extensions", "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
89-
# .config("spark.sql.catalog.integration", "org.apache.iceberg.spark.SparkCatalog")
90-
# .config("spark.sql.catalog.integration.catalog-impl", "org.apache.iceberg.rest.RESTCatalog")
91-
# .config("spark.sql.catalog.integration.uri", "http://localhost:8181")
92-
# .config("spark.sql.catalog.integration.io-impl", "org.apache.iceberg.aws.s3.S3FileIO")
93-
# .config("spark.sql.catalog.integration.warehouse", "s3://warehouse/wh/")
94-
# .config("spark.sql.catalog.integration.s3.endpoint", "http://localhost:9000")
95-
# .config("spark.sql.catalog.integration.s3.path-style-access", "true")
96-
# .config("spark.sql.defaultCatalog", "integration")
97-
# .getOrCreate()
98-
# )
99-
100-
# return spark
101-
102-
10357
TABLE_SCHEMA = Schema(
10458
NestedField(field_id=1, name="boolean_field", field_type=BooleanType(), required=False),
10559
NestedField(field_id=2, name="string_field", field_type=StringType(), required=False),
@@ -112,10 +66,10 @@
11266
NestedField(field_id=9, name="timestamptz_field", field_type=TimestamptzType(), required=False),
11367
NestedField(field_id=10, name="date_field", field_type=DateType(), required=False),
11468
# NestedField(field_id=11, name="time", field_type=TimeType(), required=False),
115-
# NestedField(field_id=12, name="uuid", field_type=UuidType(), required=False),
11669
NestedField(field_id=11, name="binary_field", field_type=BinaryType(), required=False),
11770
NestedField(field_id=12, name="fixed_field", field_type=FixedType(16), required=False),
118-
NestedField(field_id=13, name="decimal", field_type=DecimalType(5, 2), required=False),
71+
NestedField(field_id=13, name="decimal_field", field_type=DecimalType(5, 2), required=False),
72+
NestedField(field_id=14, name="uuid_field", field_type=UUIDType(), required=False),
11973
)
12074

12175

@@ -353,6 +307,25 @@
353307
(CAST('2023-01-01' AS DATE), 'Associated string value for date 2023-01-01')
354308
""",
355309
),
310+
(
311+
[PartitionField(source_id=14, field_id=1001, transform=IdentityTransform(), name="uuid_field")],
312+
[uuid.UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")],
313+
Record(uuid_field="f47ac10b-58cc-4372-a567-0e02b2c3d479"),
314+
"uuid_field=f47ac10b-58cc-4372-a567-0e02b2c3d479",
315+
f"""CREATE TABLE {identifier} (
316+
uuid_field string,
317+
string_field string
318+
)
319+
USING iceberg
320+
PARTITIONED BY (
321+
identity(uuid_field)
322+
)
323+
""",
324+
f"""INSERT INTO {identifier}
325+
VALUES
326+
('f47ac10b-58cc-4372-a567-0e02b2c3d479', 'Associated string value for UUID f47ac10b-58cc-4372-a567-0e02b2c3d479')
327+
""",
328+
),
356329
(
357330
[PartitionField(source_id=11, field_id=1001, transform=IdentityTransform(), name="binary_field")],
358331
[b'example'],
@@ -770,6 +743,7 @@ def test_partition_key(
770743
partition_spec=spec,
771744
schema=TABLE_SCHEMA,
772745
)
746+
773747
# key.partition is used to write the metadata in DataFile, ManifestFile and all above layers
774748
assert key.partition == expected_partition_record
775749
# key.to_path() generates the hive partitioning part of the to-write parquet file path

0 commit comments

Comments (0)