|
16 | 16 | # under the License. |
17 | 17 | from __future__ import annotations |
18 | 18 |
|
19 | | -from functools import cached_property |
20 | | -from typing import ( |
21 | | - Any, |
22 | | - Dict, |
23 | | - List, |
24 | | - Optional, |
25 | | - Tuple, |
26 | | -) |
| 19 | +from abc import ABC, abstractmethod |
| 20 | +from functools import cached_property, singledispatch |
| 21 | +from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar |
27 | 22 |
|
28 | 23 | from pydantic import ( |
29 | 24 | BeforeValidator, |
|
34 | 29 | from typing_extensions import Annotated |
35 | 30 |
|
36 | 31 | from pyiceberg.schema import Schema |
37 | | -from pyiceberg.transforms import Transform, parse_transform |
| 32 | +from pyiceberg.transforms import ( |
| 33 | + BucketTransform, |
| 34 | + DayTransform, |
| 35 | + HourTransform, |
| 36 | + IdentityTransform, |
| 37 | + Transform, |
| 38 | + TruncateTransform, |
| 39 | + UnknownTransform, |
| 40 | + VoidTransform, |
| 41 | + YearTransform, |
| 42 | + parse_transform, |
| 43 | +) |
38 | 44 | from pyiceberg.typedef import IcebergBaseModel |
39 | 45 | from pyiceberg.types import NestedField, StructType |
40 | 46 |
|
@@ -85,6 +91,20 @@ def __str__(self) -> str: |
85 | 91 | """Return the string representation of the PartitionField class.""" |
86 | 92 | return f"{self.field_id}: {self.name}: {self.transform}({self.source_id})" |
87 | 93 |
|
| 94 | + def __hash__(self) -> int: |
| 95 | + """Return the hash of the partition field.""" |
| 96 | + return hash((self.name, self.source_id, self.field_id, repr(self.transform))) |
| 97 | + |
| 98 | + def __eq__(self, other: Any) -> bool: |
| 99 | + """Return True if two partition fields are considered equal, False otherwise.""" |
| 100 | + return ( |
| 101 | + isinstance(other, PartitionField) |
| 102 | + and other.field_id == self.field_id |
| 103 | + and other.name == self.name |
| 104 | + and other.source_id == self.source_id |
| 105 | + and repr(other.transform) == repr(self.transform) |
| 106 | + ) |
| 107 | + |
88 | 108 |
|
89 | 109 | class PartitionSpec(IcebergBaseModel): |
90 | 110 | """ |
@@ -215,3 +235,111 @@ def assign_fresh_partition_spec_ids(spec: PartitionSpec, old_schema: Schema, fre |
215 | 235 | ) |
216 | 236 | ) |
217 | 237 | return PartitionSpec(*partition_fields, spec_id=INITIAL_PARTITION_SPEC_ID) |
| 238 | + |
| 239 | + |
| 240 | +T = TypeVar("T") |
| 241 | + |
| 242 | + |
| 243 | +class PartitionSpecVisitor(Generic[T], ABC): |
| 244 | + @abstractmethod |
| 245 | + def identity(self, field_id: int, source_name: str, source_id: int) -> T: |
| 246 | + """Visit identity partition field.""" |
| 247 | + |
| 248 | + @abstractmethod |
| 249 | + def bucket(self, field_id: int, source_name: str, source_id: int, num_buckets: int) -> T: |
| 250 | + """Visit bucket partition field.""" |
| 251 | + |
| 252 | + @abstractmethod |
| 253 | + def truncate(self, field_id: int, source_name: str, source_id: int, width: int) -> T: |
| 254 | + """Visit truncate partition field.""" |
| 255 | + |
| 256 | + @abstractmethod |
| 257 | + def year(self, field_id: int, source_name: str, source_id: int) -> T: |
| 258 | + """Visit year partition field.""" |
| 259 | + |
| 260 | + @abstractmethod |
| 261 | + def month(self, field_id: int, source_name: str, source_id: int) -> T: |
| 262 | + """Visit month partition field.""" |
| 263 | + |
| 264 | + @abstractmethod |
| 265 | + def day(self, field_id: int, source_name: str, source_id: int) -> T: |
| 266 | + """Visit day partition field.""" |
| 267 | + |
| 268 | + @abstractmethod |
| 269 | + def hour(self, field_id: int, source_name: str, source_id: int) -> T: |
| 270 | + """Visit hour partition field.""" |
| 271 | + |
| 272 | + @abstractmethod |
| 273 | + def always_null(self, field_id: int, source_name: str, source_id: int) -> T: |
| 274 | + """Visit void partition field.""" |
| 275 | + |
| 276 | + @abstractmethod |
| 277 | + def unknown(self, field_id: int, source_name: str, source_id: int, transform: str) -> T: |
| 278 | + """Visit unknown partition field.""" |
| 279 | + raise ValueError(f"Unknown transform {transform} is not supported") |
| 280 | + |
| 281 | + |
| 282 | +class _PartitionNameGenerator(PartitionSpecVisitor[str]): |
| 283 | + def identity(self, field_id: int, source_name: str, source_id: int) -> str: |
| 284 | + return source_name |
| 285 | + |
| 286 | + def bucket(self, field_id: int, source_name: str, source_id: int, num_buckets: int) -> str: |
| 287 | + return source_name + "_bucket_" + str(num_buckets) |
| 288 | + |
| 289 | + def truncate(self, field_id: int, source_name: str, source_id: int, width: int) -> str: |
| 290 | + return source_name + "_trunc_" + str(width) |
| 291 | + |
| 292 | + def year(self, field_id: int, source_name: str, source_id: int) -> str: |
| 293 | + return source_name + "_year" |
| 294 | + |
| 295 | + def month(self, field_id: int, source_name: str, source_id: int) -> str: |
| 296 | + return source_name + "_month" |
| 297 | + |
| 298 | + def day(self, field_id: int, source_name: str, source_id: int) -> str: |
| 299 | + return source_name + "_day" |
| 300 | + |
| 301 | + def hour(self, field_id: int, source_name: str, source_id: int) -> str: |
| 302 | + return source_name + "_hour" |
| 303 | + |
| 304 | + def always_null(self, field_id: int, source_name: str, source_id: int) -> str: |
| 305 | + return source_name + "_null" |
| 306 | + |
| 307 | + def unknown(self, field_id: int, source_name: str, source_id: int, transform: str) -> str: |
| 308 | + return super().unknown(field_id, source_name, source_id, transform) |
| 309 | + |
| 310 | + |
| 311 | +R = TypeVar("R") |
| 312 | + |
| 313 | + |
| 314 | +@singledispatch |
| 315 | +def _visit(spec: PartitionSpec, schema: Schema, visitor: PartitionSpecVisitor[R]) -> List[R]: |
| 316 | + results = [] |
| 317 | + for field in spec.fields: |
| 318 | + results.append(_visit_field(schema, field, visitor)) |
| 319 | + return results |
| 320 | + |
| 321 | + |
| 322 | +def _visit_field(schema: Schema, field: PartitionField, visitor: PartitionSpecVisitor[R]) -> R: |
| 323 | + source_name = schema.find_column_name(field.source_id) |
| 324 | + if not source_name: |
| 325 | + raise ValueError(f"Could not find column with id {field.source_id}") |
| 326 | + |
| 327 | + transform = field.transform |
| 328 | + if isinstance(transform, IdentityTransform): |
| 329 | + return visitor.identity(field.field_id, source_name, field.source_id) |
| 330 | + elif isinstance(transform, BucketTransform): |
| 331 | + return visitor.bucket(field.field_id, source_name, field.source_id, transform.num_buckets) |
| 332 | + elif isinstance(transform, TruncateTransform): |
| 333 | + return visitor.truncate(field.field_id, source_name, field.source_id, transform.width) |
| 334 | + elif isinstance(transform, DayTransform): |
| 335 | + return visitor.day(field.field_id, source_name, field.source_id) |
| 336 | + elif isinstance(transform, HourTransform): |
| 337 | + return visitor.hour(field.field_id, source_name, field.source_id) |
| 338 | + elif isinstance(transform, YearTransform): |
| 339 | + return visitor.year(field.field_id, source_name, field.source_id) |
| 340 | + elif isinstance(transform, VoidTransform): |
| 341 | + return visitor.always_null(field.field_id, source_name, field.source_id) |
| 342 | + elif isinstance(transform, UnknownTransform): |
| 343 | + return visitor.unknown(field.field_id, source_name, field.source_id, repr(transform)) |
| 344 | + else: |
| 345 | + raise ValueError(f"Unknown transform {transform}") |
0 commit comments