|
20 | 20 | from abc import ABC, abstractmethod |
21 | 21 | from enum import IntEnum |
22 | 22 | from functools import singledispatch |
23 | | -from typing import Any, Callable, Generic, Optional, TypeVar |
| 23 | +from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar |
24 | 24 | from typing import Literal as LiteralType |
25 | 25 | from uuid import UUID |
26 | 26 |
|
|
82 | 82 | from pyiceberg.utils.parsing import ParseNumberFromBrackets |
83 | 83 | from pyiceberg.utils.singleton import Singleton |
84 | 84 |
|
| 85 | +if TYPE_CHECKING: |
| 86 | + import pyarrow as pa |
| 87 | + |
85 | 88 | S = TypeVar("S") |
86 | 89 | T = TypeVar("T") |
87 | 90 |
|
@@ -175,6 +178,13 @@ def __eq__(self, other: Any) -> bool: |
175 | 178 | return self.root == other.root |
176 | 179 | return False |
177 | 180 |
|
| 181 | + @property |
| 182 | + def supports_pyarrow_transform(self) -> bool: |
| 183 | + return False |
| 184 | + |
| 185 | + @abstractmethod |
| 186 | + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": ... |
| 187 | + |
178 | 188 |
|
179 | 189 | class BucketTransform(Transform[S, int]): |
180 | 190 | """Base Transform class to transform a value into a bucket partition value. |
@@ -290,6 +300,9 @@ def __repr__(self) -> str: |
290 | 300 | """Return the string representation of the BucketTransform class.""" |
291 | 301 | return f"BucketTransform(num_buckets={self._num_buckets})" |
292 | 302 |
|
| 303 | + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": |
| 304 | + raise NotImplementedError() |
| 305 | + |
293 | 306 |
|
294 | 307 | class TimeResolution(IntEnum): |
295 | 308 | YEAR = 6 |
@@ -349,6 +362,10 @@ def dedup_name(self) -> str: |
349 | 362 | def preserves_order(self) -> bool: |
350 | 363 | return True |
351 | 364 |
|
| 365 | + @property |
| 366 | + def supports_pyarrow_transform(self) -> bool: |
| 367 | + return True |
| 368 | + |
352 | 369 |
|
353 | 370 | class YearTransform(TimeTransform[S]): |
354 | 371 | """Transforms a datetime value into a year value. |
@@ -391,6 +408,21 @@ def __repr__(self) -> str: |
391 | 408 | """Return the string representation of the YearTransform class.""" |
392 | 409 | return "YearTransform()" |
393 | 410 |
|
| 411 | + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": |
| 412 | + import pyarrow as pa |
| 413 | + import pyarrow.compute as pc |
| 414 | + |
| 415 | + if isinstance(source, DateType): |
| 416 | + epoch = datetime.EPOCH_DATE |
| 417 | + elif isinstance(source, TimestampType): |
| 418 | + epoch = datetime.EPOCH_TIMESTAMP |
| 419 | + elif isinstance(source, TimestamptzType): |
| 420 | + epoch = datetime.EPOCH_TIMESTAMPTZ |
| 421 | + else: |
| 422 | + raise ValueError(f"Cannot apply year transform for type: {source}") |
| 423 | + |
| 424 | + return lambda v: pc.years_between(pa.scalar(epoch), v) if v is not None else None |
| 425 | + |
394 | 426 |
|
395 | 427 | class MonthTransform(TimeTransform[S]): |
396 | 428 | """Transforms a datetime value into a month value. |
@@ -433,6 +465,27 @@ def __repr__(self) -> str: |
433 | 465 | """Return the string representation of the MonthTransform class.""" |
434 | 466 | return "MonthTransform()" |
435 | 467 |
|
| 468 | + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": |
| 469 | + import pyarrow as pa |
| 470 | + import pyarrow.compute as pc |
| 471 | + |
| 472 | + if isinstance(source, DateType): |
| 473 | + epoch = datetime.EPOCH_DATE |
| 474 | + elif isinstance(source, TimestampType): |
| 475 | + epoch = datetime.EPOCH_TIMESTAMP |
| 476 | + elif isinstance(source, TimestamptzType): |
| 477 | + epoch = datetime.EPOCH_TIMESTAMPTZ |
| 478 | + else: |
| 479 | + raise ValueError(f"Cannot apply month transform for type: {source}") |
| 480 | + |
| 481 | + def month_func(v: pa.Array) -> pa.Array: |
| 482 | + return pc.add( |
| 483 | + pc.multiply(pc.years_between(pa.scalar(epoch), v), pa.scalar(12)), |
| 484 | + pc.add(pc.month(v), pa.scalar(-1)), |
| 485 | + ) |
| 486 | + |
| 487 | + return lambda v: month_func(v) if v is not None else None |
| 488 | + |
436 | 489 |
|
437 | 490 | class DayTransform(TimeTransform[S]): |
438 | 491 | """Transforms a datetime value into a day value. |
@@ -478,6 +531,21 @@ def __repr__(self) -> str: |
478 | 531 | """Return the string representation of the DayTransform class.""" |
479 | 532 | return "DayTransform()" |
480 | 533 |
|
| 534 | + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": |
| 535 | + import pyarrow as pa |
| 536 | + import pyarrow.compute as pc |
| 537 | + |
| 538 | + if isinstance(source, DateType): |
| 539 | + epoch = datetime.EPOCH_DATE |
| 540 | + elif isinstance(source, TimestampType): |
| 541 | + epoch = datetime.EPOCH_TIMESTAMP |
| 542 | + elif isinstance(source, TimestamptzType): |
| 543 | + epoch = datetime.EPOCH_TIMESTAMPTZ |
| 544 | + else: |
| 545 | + raise ValueError(f"Cannot apply day transform for type: {source}") |
| 546 | + |
| 547 | + return lambda v: pc.days_between(pa.scalar(epoch), v) if v is not None else None |
| 548 | + |
481 | 549 |
|
482 | 550 | class HourTransform(TimeTransform[S]): |
483 | 551 | """Transforms a datetime value into a hour value. |
@@ -515,6 +583,19 @@ def __repr__(self) -> str: |
515 | 583 | """Return the string representation of the HourTransform class.""" |
516 | 584 | return "HourTransform()" |
517 | 585 |
|
| 586 | + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": |
| 587 | + import pyarrow as pa |
| 588 | + import pyarrow.compute as pc |
| 589 | + |
| 590 | + if isinstance(source, TimestampType): |
| 591 | + epoch = datetime.EPOCH_TIMESTAMP |
| 592 | + elif isinstance(source, TimestamptzType): |
| 593 | + epoch = datetime.EPOCH_TIMESTAMPTZ |
| 594 | + else: |
| 595 | + raise ValueError(f"Cannot apply hour transform for type: {source}") |
| 596 | + |
| 597 | + return lambda v: pc.hours_between(pa.scalar(epoch), v) if v is not None else None |
| 598 | + |
518 | 599 |
|
519 | 600 | def _base64encode(buffer: bytes) -> str: |
520 | 601 | """Convert bytes to base64 string.""" |
@@ -585,6 +666,13 @@ def __repr__(self) -> str: |
585 | 666 | """Return the string representation of the IdentityTransform class.""" |
586 | 667 | return "IdentityTransform()" |
587 | 668 |
|
| 669 | + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": |
| 670 | + return lambda v: v |
| 671 | + |
| 672 | + @property |
| 673 | + def supports_pyarrow_transform(self) -> bool: |
| 674 | + return True |
| 675 | + |
588 | 676 |
|
589 | 677 | class TruncateTransform(Transform[S, S]): |
590 | 678 | """A transform for truncating a value to a specified width. |
@@ -725,6 +813,9 @@ def __repr__(self) -> str: |
725 | 813 | """Return the string representation of the TruncateTransform class.""" |
726 | 814 | return f"TruncateTransform(width={self._width})" |
727 | 815 |
|
| 816 | + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": |
| 817 | + raise NotImplementedError() |
| 818 | + |
728 | 819 |
|
729 | 820 | @singledispatch |
730 | 821 | def _human_string(value: Any, _type: IcebergType) -> str: |
@@ -807,6 +898,9 @@ def __repr__(self) -> str: |
807 | 898 | """Return the string representation of the UnknownTransform class.""" |
808 | 899 | return f"UnknownTransform(transform={repr(self._transform)})" |
809 | 900 |
|
| 901 | + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": |
| 902 | + raise NotImplementedError() |
| 903 | + |
810 | 904 |
|
811 | 905 | class VoidTransform(Transform[S, None], Singleton): |
812 | 906 | """A transform that always returns None.""" |
@@ -835,6 +929,9 @@ def __repr__(self) -> str: |
835 | 929 | """Return the string representation of the VoidTransform class.""" |
836 | 930 | return "VoidTransform()" |
837 | 931 |
|
| 932 | + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": |
| 933 | + raise NotImplementedError() |
| 934 | + |
838 | 935 |
|
839 | 936 | def _truncate_number( |
840 | 937 | name: str, pred: BoundLiteralPredicate[L], transform: Callable[[Optional[L]], Optional[L]] |
|
0 commit comments