|
16 | 16 | # under the License. |
17 | 17 | from __future__ import annotations |
18 | 18 |
|
19 | | -from functools import cached_property |
20 | | -from typing import ( |
21 | | - Any, |
22 | | - Dict, |
23 | | - List, |
24 | | - Optional, |
25 | | - Tuple, |
26 | | -) |
| 19 | +from abc import ABC, abstractmethod |
| 20 | +from functools import cached_property, singledispatch |
| 21 | +from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar |
27 | 22 |
|
28 | 23 | from pydantic import ( |
29 | 24 | BeforeValidator, |
|
34 | 29 | from typing_extensions import Annotated |
35 | 30 |
|
36 | 31 | from pyiceberg.schema import Schema |
37 | | -from pyiceberg.transforms import Transform, parse_transform |
| 32 | +from pyiceberg.transforms import ( |
| 33 | + BucketTransform, |
| 34 | + DayTransform, |
| 35 | + HourTransform, |
| 36 | + IdentityTransform, |
| 37 | + Transform, |
| 38 | + TruncateTransform, |
| 39 | + UnknownTransform, |
| 40 | + VoidTransform, |
| 41 | + YearTransform, |
| 42 | + parse_transform, |
| 43 | +) |
38 | 44 | from pyiceberg.typedef import IcebergBaseModel |
39 | 45 | from pyiceberg.types import NestedField, StructType |
40 | 46 |
|
@@ -215,3 +221,108 @@ def assign_fresh_partition_spec_ids(spec: PartitionSpec, old_schema: Schema, fre |
215 | 221 | ) |
216 | 222 | ) |
217 | 223 | return PartitionSpec(*partition_fields, spec_id=INITIAL_PARTITION_SPEC_ID) |
| 224 | + |
| 225 | + |
| 226 | +T = TypeVar("T") |
| 227 | + |
| 228 | + |
| 229 | +class PartitionSpecVisitor(Generic[T], ABC): |
| 230 | + @abstractmethod |
| 231 | + def identity(self, field_id: int, source_name: str, source_id: int) -> T: |
| 232 | + """Visit identity partition field.""" |
| 233 | + |
| 234 | + @abstractmethod |
| 235 | + def bucket(self, field_id: int, source_name: str, source_id: int, num_buckets: int) -> T: |
| 236 | + """Visit bucket partition field.""" |
| 237 | + |
| 238 | + @abstractmethod |
| 239 | + def truncate(self, field_id: int, source_name: str, source_id: int, width: int) -> T: |
| 240 | + """Visit truncate partition field.""" |
| 241 | + |
| 242 | + @abstractmethod |
| 243 | + def year(self, field_id: int, source_name: str, source_id: int) -> T: |
| 244 | + """Visit year partition field.""" |
| 245 | + |
| 246 | + @abstractmethod |
| 247 | + def month(self, field_id: int, source_name: str, source_id: int) -> T: |
| 248 | + """Visit month partition field.""" |
| 249 | + |
| 250 | + @abstractmethod |
| 251 | + def day(self, field_id: int, source_name: str, source_id: int) -> T: |
| 252 | + """Visit day partition field.""" |
| 253 | + |
| 254 | + @abstractmethod |
| 255 | + def hour(self, field_id: int, source_name: str, source_id: int) -> T: |
| 256 | + """Visit hour partition field.""" |
| 257 | + |
| 258 | + @abstractmethod |
| 259 | + def always_null(self, field_id: int, source_name: str, source_id: int) -> T: |
| 260 | + """Visit void partition field.""" |
| 261 | + |
| 262 | + @abstractmethod |
| 263 | + def unknown(self, field_id: int, source_name: str, source_id: int, transform: str) -> T: |
| 264 | + """Visit unknown partition field.""" |
| 265 | + raise ValueError(f"Unknown transform is not supported: {transform}") |
| 266 | + |
| 267 | + |
| 268 | +class _PartitionNameGenerator(PartitionSpecVisitor[str]): |
| 269 | + def identity(self, field_id: int, source_name: str, source_id: int) -> str: |
| 270 | + return source_name |
| 271 | + |
| 272 | + def bucket(self, field_id: int, source_name: str, source_id: int, num_buckets: int) -> str: |
| 273 | + return f"{source_name}_bucket_{num_buckets}" |
| 274 | + |
| 275 | + def truncate(self, field_id: int, source_name: str, source_id: int, width: int) -> str: |
| 276 | + return source_name + "_trunc_" + str(width) |
| 277 | + |
| 278 | + def year(self, field_id: int, source_name: str, source_id: int) -> str: |
| 279 | + return source_name + "_year" |
| 280 | + |
| 281 | + def month(self, field_id: int, source_name: str, source_id: int) -> str: |
| 282 | + return source_name + "_month" |
| 283 | + |
| 284 | + def day(self, field_id: int, source_name: str, source_id: int) -> str: |
| 285 | + return source_name + "_day" |
| 286 | + |
| 287 | + def hour(self, field_id: int, source_name: str, source_id: int) -> str: |
| 288 | + return source_name + "_hour" |
| 289 | + |
| 290 | + def always_null(self, field_id: int, source_name: str, source_id: int) -> str: |
| 291 | + return source_name + "_null" |
| 292 | + |
| 293 | + def unknown(self, field_id: int, source_name: str, source_id: int, transform: str) -> str: |
| 294 | + return super().unknown(field_id, source_name, source_id, transform) |
| 295 | + |
| 296 | + |
| 297 | +R = TypeVar("R") |
| 298 | + |
| 299 | + |
| 300 | +@singledispatch |
| 301 | +def _visit(spec: PartitionSpec, schema: Schema, visitor: PartitionSpecVisitor[R]) -> List[R]: |
| 302 | + return [_visit_partition_field(schema, field, visitor) for field in spec.fields] |
| 303 | + |
| 304 | + |
| 305 | +def _visit_partition_field(schema: Schema, field: PartitionField, visitor: PartitionSpecVisitor[R]) -> R: |
| 306 | + source_name = schema.find_column_name(field.source_id) |
| 307 | + if not source_name: |
| 308 | + raise ValueError(f"Could not find field with id {field.source_id}") |
| 309 | + |
| 310 | + transform = field.transform |
| 311 | + if isinstance(transform, IdentityTransform): |
| 312 | + return visitor.identity(field.field_id, source_name, field.source_id) |
| 313 | + elif isinstance(transform, BucketTransform): |
| 314 | + return visitor.bucket(field.field_id, source_name, field.source_id, transform.num_buckets) |
| 315 | + elif isinstance(transform, TruncateTransform): |
| 316 | + return visitor.truncate(field.field_id, source_name, field.source_id, transform.width) |
| 317 | + elif isinstance(transform, DayTransform): |
| 318 | + return visitor.day(field.field_id, source_name, field.source_id) |
| 319 | + elif isinstance(transform, HourTransform): |
| 320 | + return visitor.hour(field.field_id, source_name, field.source_id) |
| 321 | + elif isinstance(transform, YearTransform): |
| 322 | + return visitor.year(field.field_id, source_name, field.source_id) |
| 323 | + elif isinstance(transform, VoidTransform): |
| 324 | + return visitor.always_null(field.field_id, source_name, field.source_id) |
| 325 | + elif isinstance(transform, UnknownTransform): |
| 326 | + return visitor.unknown(field.field_id, source_name, field.source_id, repr(transform)) |
| 327 | + else: |
| 328 | + raise ValueError(f"Unknown transform {transform}") |
0 commit comments