|
17 | 17 | # pylint: disable=eval-used,protected-access,redefined-outer-name |
18 | 18 | from datetime import date |
19 | 19 | from decimal import Decimal |
20 | | -from typing import TYPE_CHECKING, Any, Callable, Optional |
| 20 | +from typing import Any, Callable, Optional, Union |
21 | 21 | from uuid import UUID |
22 | 22 |
|
23 | 23 | import mmh3 as mmh3 |
| 24 | +import pyarrow as pa |
24 | 25 | import pytest |
25 | 26 | from pydantic import ( |
26 | 27 | BeforeValidator, |
|
112 | 113 | timestamptz_to_micros, |
113 | 114 | ) |
114 | 115 |
|
115 | | -if TYPE_CHECKING: |
116 | | - import pyarrow as pa |
117 | | - |
118 | 116 |
|
119 | 117 | @pytest.mark.parametrize( |
120 | 118 | "test_input,test_type,expected", |
@@ -1840,3 +1838,26 @@ def test_ymd_pyarrow_transforms( |
1840 | 1838 | else: |
1841 | 1839 | with pytest.raises(ValueError): |
1842 | 1840 | transform.pyarrow_transform(DateType())(arrow_table_date_timestamps[source_col]) |
| 1841 | + |
| 1842 | + |
@pytest.mark.parametrize(
    "source_type, input_arr, expected, num_buckets",
    [
        # Plain (contiguous) array input.
        pytest.param(
            IntegerType(),
            pa.array([1, 2]),
            pa.array([6, 2], type=pa.int32()),
            10,
            id="int-array",
        ),
        # Chunked input: the expected result keeps the same chunk layout.
        pytest.param(
            IntegerType(),
            pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])]),
            pa.chunked_array([pa.array([6, 2], type=pa.int32()), pa.array([5, 0], type=pa.int32())]),
            10,
            id="int-chunked-array",
        ),
    ],
)
def test_bucket_pyarrow_transforms(
    source_type: PrimitiveType,
    input_arr: Union[pa.Array, pa.ChunkedArray],
    expected: Union[pa.Array, pa.ChunkedArray],
    num_buckets: int,
) -> None:
    """Check BucketTransform's pyarrow transform against known bucket values.

    Builds a ``BucketTransform`` with ``num_buckets`` buckets, applies the
    callable returned by ``pyarrow_transform(source_type)`` to ``input_arr``
    and asserts that the result equals ``expected``.

    Args:
        source_type: Iceberg type passed to ``pyarrow_transform``.
        input_arr: Arrow array or chunked array to be bucketed.
        expected: Arrow array or chunked array the transform must produce.
        num_buckets: Bucket count given to ``BucketTransform``.
    """
    transform: Transform[Any, Any] = BucketTransform(num_buckets=num_buckets)
    # Name the result so a failing assert reports both operands.
    actual = transform.pyarrow_transform(source_type)(input_arr)
    assert expected == actual
0 commit comments