Skip to content

Commit dd888ec

Browse files
committed
introduce bucket transform
1 parent 93ebd39 commit dd888ec

File tree

2 files changed

+46
-5
lines changed

2 files changed

+46
-5
lines changed

pyiceberg/transforms.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,27 @@ def __repr__(self) -> str:
309309
return f"BucketTransform(num_buckets={self._num_buckets})"
310310

311311
def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
312-
raise NotImplementedError()
312+
import pyarrow as pa
313+
from pyiceberg_core import transform as pyiceberg_core_transform
314+
315+
ArrayLike = TypeVar("ArrayLike", pa.Array, pa.ChunkedArray)
316+
317+
def bucket(array: ArrayLike) -> ArrayLike:
318+
if isinstance(array, pa.Array):
319+
return pyiceberg_core_transform.bucket(array, self._num_buckets)
320+
elif isinstance(array, pa.ChunkedArray):
321+
result_chunks = []
322+
for arr in array.iterchunks():
323+
result_chunks.append(pyiceberg_core_transform.bucket(arr, self._num_buckets))
324+
return pa.chunked_array(result_chunks)
325+
else:
326+
raise ValueError(f"PyArrow array can only be of type pa.Array or pa.ChunkedArray, but found {type(array)}")
327+
328+
return bucket
329+
330+
@property
331+
def supports_pyarrow_transform(self) -> bool:
332+
return True
313333

314334

315335
class TimeResolution(IntEnum):

tests/test_transforms.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@
1717
# pylint: disable=eval-used,protected-access,redefined-outer-name
1818
from datetime import date
1919
from decimal import Decimal
20-
from typing import TYPE_CHECKING, Any, Callable, Optional
20+
from typing import Any, Callable, Optional, Union
2121
from uuid import UUID
2222

2323
import mmh3 as mmh3
24+
import pyarrow as pa
2425
import pytest
2526
from pydantic import (
2627
BeforeValidator,
@@ -112,9 +113,6 @@
112113
timestamptz_to_micros,
113114
)
114115

115-
if TYPE_CHECKING:
116-
import pyarrow as pa
117-
118116

119117
@pytest.mark.parametrize(
120118
"test_input,test_type,expected",
@@ -1840,3 +1838,26 @@ def test_ymd_pyarrow_transforms(
18401838
else:
18411839
with pytest.raises(ValueError):
18421840
transform.pyarrow_transform(DateType())(arrow_table_date_timestamps[source_col])
1841+
1842+
1843+
@pytest.mark.parametrize(
1844+
"source_type, input_arr, expected, num_buckets",
1845+
[
1846+
(IntegerType(), pa.array([1, 2]), pa.array([6, 2], type=pa.int32()), 10),
1847+
(
1848+
IntegerType(),
1849+
pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])]),
1850+
pa.chunked_array([pa.array([6, 2], type=pa.int32()), pa.array([5, 0], type=pa.int32())]),
1851+
10,
1852+
),
1853+
(IntegerType(), pa.array([1, 2]), pa.array([6, 2], type=pa.int32()), 10),
1854+
],
1855+
)
1856+
def test_bucket_pyarrow_transforms(
1857+
source_type: PrimitiveType,
1858+
input_arr: Union[pa.Array, pa.ChunkedArray],
1859+
expected: Union[pa.Array, pa.ChunkedArray],
1860+
num_buckets: int,
1861+
) -> None:
1862+
transform: Transform[Any, Any] = BucketTransform(num_buckets=num_buckets)
1863+
assert expected == transform.pyarrow_transform(source_type)(input_arr)

0 commit comments

Comments
 (0)