Skip to content

Commit c80f5f7

Browse files
Initial partition evolution
1 parent 567ec49 commit c80f5f7

File tree

2 files changed

+378
-9
lines changed

2 files changed

+378
-9
lines changed

pyiceberg/partitioning.py

Lines changed: 135 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,9 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19+
from abc import ABC, abstractmethod
1920
from functools import cached_property
20-
from typing import (
21-
Any,
22-
Dict,
23-
List,
24-
Optional,
25-
Tuple,
26-
)
21+
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar
2722

2823
from pydantic import (
2924
BeforeValidator,
@@ -34,7 +29,18 @@
3429
from typing_extensions import Annotated
3530

3631
from pyiceberg.schema import Schema
37-
from pyiceberg.transforms import Transform, parse_transform
32+
from pyiceberg.transforms import (
33+
BucketTransform,
34+
DayTransform,
35+
HourTransform,
36+
IdentityTransform,
37+
Transform,
38+
TruncateTransform,
39+
UnknownTransform,
40+
VoidTransform,
41+
YearTransform,
42+
parse_transform,
43+
)
3844
from pyiceberg.typedef import IcebergBaseModel
3945
from pyiceberg.types import NestedField, StructType
4046

@@ -85,6 +91,20 @@ def __str__(self) -> str:
8591
"""Return the string representation of the PartitionField class."""
8692
return f"{self.field_id}: {self.name}: {self.transform}({self.source_id})"
8793

94+
def __hash__(self) -> int:
95+
"""Return the hash of the partition field."""
96+
return hash((self.name, self.source_id, self.field_id, repr(self.transform)))
97+
98+
def __eq__(self, other: Any) -> bool:
99+
"""Return True if two partition fields are considered equal, False otherwise."""
100+
return (
101+
isinstance(other, PartitionField)
102+
and other.field_id == self.field_id
103+
and other.name == self.name
104+
and other.source_id == self.source_id
105+
and repr(other.transform) == repr(self.transform)
106+
)
107+
88108

89109
class PartitionSpec(IcebergBaseModel):
90110
"""
@@ -215,3 +235,110 @@ def assign_fresh_partition_spec_ids(spec: PartitionSpec, old_schema: Schema, fre
215235
)
216236
)
217237
return PartitionSpec(*partition_fields, spec_id=INITIAL_PARTITION_SPEC_ID)
238+
239+
240+
T = TypeVar("T")
241+
242+
243+
class PartitionSpecVisitor(Generic[T], ABC):
244+
@abstractmethod
245+
def identity(self, field_id: int, source_name: str, source_id: int) -> T:
246+
"""Visit identity partition field."""
247+
248+
@abstractmethod
249+
def bucket(self, field_id: int, source_name: str, source_id: int, num_buckets: int) -> T:
250+
"""Visit bucket partition field."""
251+
252+
@abstractmethod
253+
def truncate(self, field_id: int, source_name: str, source_id: int, width: int) -> T:
254+
"""Visit truncate partition field."""
255+
256+
@abstractmethod
257+
def year(self, field_id: int, source_name: str, source_id: int) -> T:
258+
"""Visit year partition field."""
259+
260+
@abstractmethod
261+
def month(self, field_id: int, source_name: str, source_id: int) -> T:
262+
"""Visit month partition field."""
263+
264+
@abstractmethod
265+
def day(self, field_id: int, source_name: str, source_id: int) -> T:
266+
"""Visit day partition field."""
267+
268+
@abstractmethod
269+
def hour(self, field_id: int, source_name: str, source_id: int) -> T:
270+
"""Visit hour partition field."""
271+
272+
@abstractmethod
273+
def always_null(self, field_id: int, source_name: str, source_id: int) -> T:
274+
"""Visit void partition field."""
275+
276+
@abstractmethod
277+
def unknown(self, field_id: int, source_name: str, source_id: int, transform: str) -> T:
278+
"""Visit unknown partition field."""
279+
raise ValueError(f"Unknown transform {transform} is not supported")
280+
281+
282+
class _PartitionNameGenerator(PartitionSpecVisitor[str]):
283+
def identity(self, field_id: int, source_name: str, source_id: int) -> str:
284+
return source_name
285+
286+
def bucket(self, field_id: int, source_name: str, source_id: int, num_buckets: int) -> str:
287+
return source_name + "_bucket_" + str(num_buckets)
288+
289+
def truncate(self, field_id: int, source_name: str, source_id: int, width: int) -> str:
290+
return source_name + "_trunc_" + str(width)
291+
292+
def year(self, field_id: int, source_name: str, source_id: int) -> str:
293+
return source_name + "_year"
294+
295+
def month(self, field_id: int, source_name: str, source_id: int) -> str:
296+
return source_name + "_month"
297+
298+
def day(self, field_id: int, source_name: str, source_id: int) -> str:
299+
return source_name + "_day"
300+
301+
def hour(self, field_id: int, source_name: str, source_id: int) -> str:
302+
return source_name + "_hour"
303+
304+
def always_null(self, field_id: int, source_name: str, source_id: int) -> str:
305+
return source_name + "_null"
306+
307+
def unknown(self, field_id: int, source_name: str, source_id: int, transform: str) -> str:
308+
return super().unknown(field_id, source_name, source_id, transform)
309+
310+
311+
R = TypeVar("R")
312+
313+
314+
def _visit(spec: PartitionSpec, schema: Schema, visitor: PartitionSpecVisitor[R]) -> List[R]:
315+
results = []
316+
for field in spec.fields:
317+
results.append(_visit_field(schema, field, visitor))
318+
return results
319+
320+
321+
def _visit_field(schema: Schema, field: PartitionField, visitor: PartitionSpecVisitor[R]) -> R:
322+
source_name = schema.find_column_name(field.source_id)
323+
if not source_name:
324+
raise ValueError(f"Could not find column with id {field.source_id}")
325+
326+
transform = field.transform
327+
if isinstance(transform, IdentityTransform):
328+
return visitor.identity(field.field_id, source_name, field.source_id)
329+
elif isinstance(transform, BucketTransform):
330+
return visitor.bucket(field.field_id, source_name, field.source_id, transform.num_buckets)
331+
elif isinstance(transform, TruncateTransform):
332+
return visitor.truncate(field.field_id, source_name, field.source_id, transform.width)
333+
elif isinstance(transform, DayTransform):
334+
return visitor.day(field.field_id, source_name, field.source_id)
335+
elif isinstance(transform, HourTransform):
336+
return visitor.hour(field.field_id, source_name, field.source_id)
337+
elif isinstance(transform, YearTransform):
338+
return visitor.year(field.field_id, source_name, field.source_id)
339+
elif isinstance(transform, VoidTransform):
340+
return visitor.always_null(field.field_id, source_name, field.source_id)
341+
elif isinstance(transform, UnknownTransform):
342+
return visitor.unknown(field.field_id, source_name, field.source_id, repr(transform))
343+
else:
344+
raise ValueError(f"Unknown transform {transform}")

0 commit comments

Comments
 (0)