101 changes: 98 additions & 3 deletions pyiceberg/partitioning.py
@@ -16,9 +16,21 @@
# under the License.
from __future__ import annotations

import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import date, datetime
from functools import cached_property, singledispatch
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar
from typing import (
Any,
Dict,
Generic,
List,
Optional,
Tuple,
TypeVar,
)
from urllib.parse import quote

from pydantic import (
BeforeValidator,
@@ -41,8 +53,18 @@
YearTransform,
parse_transform,
)
from pyiceberg.typedef import IcebergBaseModel
from pyiceberg.types import NestedField, StructType
from pyiceberg.typedef import IcebergBaseModel, Record
from pyiceberg.types import (
DateType,
IcebergType,
NestedField,
PrimitiveType,
StructType,
TimestampType,
TimestamptzType,
UUIDType,
)
from pyiceberg.utils.datetime import date_to_days, datetime_to_micros

INITIAL_PARTITION_SPEC_ID = 0
PARTITION_FIELD_ID_START: int = 1000
@@ -199,6 +221,23 @@ def partition_type(self, schema: Schema) -> StructType:
nested_fields.append(NestedField(field.field_id, field.name, result_type, required=False))
return StructType(*nested_fields)

def partition_to_path(self, data: Record, schema: Schema) -> str:
partition_type = self.partition_type(schema)
field_types = partition_type.fields

field_strs = []
value_strs = []
for pos, value in enumerate(data.record_fields()):
partition_field = self.fields[pos]
value_str = partition_field.transform.to_human_string(field_types[pos].field_type, value=value)

value_str = quote(value_str, safe='')
value_strs.append(value_str)
field_strs.append(partition_field.name)

path = "/".join([field_str + "=" + value_str for field_str, value_str in zip(field_strs, value_strs)])
return path
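A minimal sketch of how the new partition_to_path might be exercised (not part of the diff; the field names, IDs, and the expected path are illustrative assumptions):

```python
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.transforms import DayTransform
from pyiceberg.typedef import Record
from pyiceberg.types import NestedField, TimestampType

# Hypothetical single-column schema, partitioned by day on a timestamp field.
schema = Schema(NestedField(field_id=1, name="ts", field_type=TimestampType(), required=False))
spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=DayTransform(), name="ts_day"))

# The Record holds the already-transformed value (days since epoch for DayTransform);
# 19723 days corresponds to 2024-01-01, so this should render as "ts_day=2024-01-01".
print(spec.partition_to_path(Record(ts_day=19723), schema))
```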


UNPARTITIONED_PARTITION_SPEC = PartitionSpec(spec_id=0)

@@ -326,3 +365,59 @@ def _visit_partition_field(schema: Schema, field: PartitionField, visitor: Parti
return visitor.unknown(field.field_id, source_name, field.source_id, repr(transform))
else:
raise ValueError(f"Unknown transform {transform}")


@dataclass(frozen=True)
class PartitionFieldValue:
field: PartitionField
value: Any


@dataclass(frozen=True)
class PartitionKey:
raw_partition_field_values: List[PartitionFieldValue]
Review comment from the Contributor Author (@jqin61):
Spark builds a row accessor that takes an Arrow table row and converts it to key values. That accessor seems a little unnecessary here, since a partition source field cannot be nested or be a map/list, so this class just uses a plain list of field-value pairs. Willing to change it if this is inappropriate.

partition_spec: PartitionSpec
schema: Schema

@cached_property
def partition(self) -> Record: # partition key transformed with iceberg internal representation as input
iceberg_typed_key_values = {}
for raw_partition_field_value in self.raw_partition_field_values:
partition_fields = self.partition_spec.source_id_to_fields_map[raw_partition_field_value.field.source_id]
if len(partition_fields) != 1:
raise ValueError("partition_fields must contain exactly one field.")
partition_field = partition_fields[0]
iceberg_type = self.schema.find_field(name_or_id=raw_partition_field_value.field.source_id).field_type
iceberg_typed_value = _to_partition_representation(iceberg_type, raw_partition_field_value.value)
transformed_value = partition_field.transform.transform(iceberg_type)(iceberg_typed_value)
iceberg_typed_key_values[partition_field.name] = transformed_value
return Record(**iceberg_typed_key_values)
Review comment from @Fokko (Contributor):
We're now getting into the realm of premature optimization, but ideally you don't need to set the names of the keys. The concept of a Record is that it only contains the data. Just below, in self.partition_spec.partition_to_path(self.partition, self.schema), you can see that you pass in both the partition and the schema itself; the positions in the schema should match the data.

Reply from the Contributor Author (@jqin61, Feb 28, 2024):
Hi @Fokko, thanks for the guidance! I added the keys because PartitionKey.partition is not only used for generating the file path but also to initialize DataFile.partition in io.pyarrow.write_file(). As the integration test shows,

snapshot.manifests(iceberg_table.io)[0].fetch_manifest_entry(iceberg_table.io)[0].data_file.partition

prints

Record(timestamp_field=1672574401000000)

so I assume this data_file.partition is a Record with keys.
Let me know what you think about it, thank you!


def to_path(self) -> str:
return self.partition_spec.partition_to_path(self.partition, self.schema)
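For context, a hedged end-to-end sketch of the intended flow (not from the diff): raw Python values go in via PartitionFieldValue, and the key exposes both the transformed Record and the Hive-style path. Names and IDs are made up; the timestamp is chosen to correspond to the value the reply above quotes from the integration test.

```python
from datetime import datetime

from pyiceberg.partitioning import (
    PartitionField,
    PartitionFieldValue,
    PartitionKey,
    PartitionSpec,
)
from pyiceberg.schema import Schema
from pyiceberg.transforms import DayTransform
from pyiceberg.types import NestedField, TimestampType

schema = Schema(NestedField(field_id=1, name="ts", field_type=TimestampType(), required=False))
spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=DayTransform(), name="ts_day"))

key = PartitionKey(
    raw_partition_field_values=[PartitionFieldValue(spec.fields[0], datetime(2023, 1, 1, 12, 0, 1))],
    partition_spec=spec,
    schema=schema,
)
# The datetime is converted to microseconds since epoch and then day-transformed,
# so the partition Record should carry days since epoch (19358 ~ 2023-01-01).
print(key.partition)  # e.g. Record(ts_day=19358)
print(key.to_path())  # e.g. "ts_day=2023-01-01"
```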


@singledispatch
def _to_partition_representation(type: IcebergType, value: Any) -> Any:
raise TypeError(f"Unsupported partition field type: {type}")


@_to_partition_representation.register(TimestampType)
@_to_partition_representation.register(TimestamptzType)
def _(type: IcebergType, value: Optional[datetime]) -> Optional[int]:
return datetime_to_micros(value) if value is not None else None


@_to_partition_representation.register(DateType)
def _(type: IcebergType, value: Optional[date]) -> Optional[int]:
return date_to_days(value) if value is not None else None


@_to_partition_representation.register(UUIDType)
def _(type: IcebergType, value: Optional[uuid.UUID]) -> Optional[str]:
return str(value) if value is not None else None


@_to_partition_representation.register(PrimitiveType)
def _(type: IcebergType, value: Optional[Any]) -> Optional[Any]:
return value
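A brief illustration of the dispatch behaviour above (my own sketch, not in the PR; the expected values in the comments are assumptions):

```python
import uuid
from datetime import date, datetime

from pyiceberg.types import DateType, StringType, TimestampType, UUIDType

# Dates become days since epoch, timestamps become microseconds since epoch,
# UUIDs become their string form, and other primitive values pass through unchanged.
_to_partition_representation(DateType(), date(2024, 1, 1))           # 19723
_to_partition_representation(TimestampType(), datetime(2024, 1, 1))  # 1704067200000000
_to_partition_representation(UUIDType(), uuid.UUID(int=0))           # "00000000-0000-0000-0000-000000000000"
_to_partition_representation(StringType(), "foo")                    # "foo"
```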
5 changes: 5 additions & 0 deletions pyiceberg/transforms.py
@@ -655,6 +655,11 @@ def _(value: int, _type: IcebergType) -> str:
return _int_to_human_string(_type, value)


@_human_string.register(bool)
def _(value: bool, _type: IcebergType) -> str:
return str(value).lower()


@singledispatch
def _int_to_human_string(_type: IcebergType, value: int) -> str:
return str(value)
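A quick, hedged check of the new bool case (not in the diff): Python's str(True) is "True", whereas partition paths are expected to use the lowercase form, which is what the registration above produces.

```python
from pyiceberg.transforms import IdentityTransform
from pyiceberg.types import BooleanType

# With the bool registration, a boolean partition value should render lowercase.
assert IdentityTransform().to_human_string(BooleanType(), value=True) == "true"
```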
51 changes: 50 additions & 1 deletion tests/conftest.py
@@ -46,9 +46,10 @@
import boto3
import pytest
from moto import mock_aws
from pyspark.sql import SparkSession

from pyiceberg import schema
from pyiceberg.catalog import Catalog
from pyiceberg.catalog import Catalog, load_catalog
from pyiceberg.catalog.noop import NoopCatalog
from pyiceberg.expressions import BoundReference
from pyiceberg.io import (
@@ -1925,3 +1926,51 @@ def table_v2(example_table_metadata_v2: Dict[str, Any]) -> Table:
@pytest.fixture
def bound_reference_str() -> BoundReference[str]:
return BoundReference(field=NestedField(1, "field", StringType(), required=False), accessor=Accessor(position=0, inner=None))


@pytest.fixture(scope="session")
def session_catalog() -> Catalog:
return load_catalog(
"local",
**{
"type": "rest",
"uri": "http://localhost:8181",
"s3.endpoint": "http://localhost:9000",
"s3.access-key-id": "admin",
"s3.secret-access-key": "password",
},
)


@pytest.fixture(scope="session")
def spark() -> SparkSession:
import importlib.metadata
import os

spark_version = ".".join(importlib.metadata.version("pyspark").split(".")[:2])
scala_version = "2.12"
iceberg_version = "1.4.3"

os.environ["PYSPARK_SUBMIT_ARGS"] = (
f"--packages org.apache.iceberg:iceberg-spark-runtime-{spark_version}_{scala_version}:{iceberg_version},"
f"org.apache.iceberg:iceberg-aws-bundle:{iceberg_version} pyspark-shell"
)
os.environ["AWS_REGION"] = "us-east-1"
os.environ["AWS_ACCESS_KEY_ID"] = "admin"
os.environ["AWS_SECRET_ACCESS_KEY"] = "password"

spark = (
SparkSession.builder.appName("PyIceberg integration test")
.config("spark.sql.extensions", "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
.config("spark.sql.catalog.integration", "org.apache.iceberg.spark.SparkCatalog")
.config("spark.sql.catalog.integration.catalog-impl", "org.apache.iceberg.rest.RESTCatalog")
.config("spark.sql.catalog.integration.uri", "http://localhost:8181")
.config("spark.sql.catalog.integration.io-impl", "org.apache.iceberg.aws.s3.S3FileIO")
.config("spark.sql.catalog.integration.warehouse", "s3://warehouse/wh/")
.config("spark.sql.catalog.integration.s3.endpoint", "http://localhost:9000")
.config("spark.sql.catalog.integration.s3.path-style-access", "true")
.config("spark.sql.defaultCatalog", "integration")
.getOrCreate()
)

return spark
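For illustration, a hypothetical integration test that consumes both new fixtures; the table identifier and the assertion are made up and assume the table was created earlier in the run.

```python
import pytest
from pyspark.sql import SparkSession

from pyiceberg.catalog import Catalog


@pytest.mark.integration
def test_spark_sees_pyiceberg_writes(session_catalog: Catalog, spark: SparkSession) -> None:
    # Hypothetical table name; both engines point at the same REST catalog and MinIO warehouse.
    identifier = "default.my_partitioned_table"
    table = session_catalog.load_table(identifier)
    assert spark.table(identifier).count() == len(table.scan().to_arrow())
```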