From 0cad231cdeac7c9833e08289a2dca70a452fa2c8 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Sat, 4 May 2024 02:20:39 +0000 Subject: [PATCH 01/21] checkpoint --- pyiceberg/transforms.py | 25 ++++++ .../test_writes/test_partitioned_writes.py | 89 +++++++++++++++++++ 2 files changed, 114 insertions(+) diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py index 6dcae59e49..c75f7861c0 100644 --- a/pyiceberg/transforms.py +++ b/pyiceberg/transforms.py @@ -433,6 +433,31 @@ def __repr__(self) -> str: """Return the string representation of the MonthTransform class.""" return "MonthTransform()" + def pyarrow_transform(self, source: IcebergType) -> Callable: + import pyarrow as pa + import pyarrow.compute as pc + + if isinstance(source, DateType): + + def month_func(v: Any) -> int: + return pc.add( + pc.multiply(pc.years_between(pa.scalar(date(1970, 1, 1)), v), pa.scalar(12)), + pc.add(pc.month(v), pa.scalar(-1)), + ) + + elif isinstance(source, (TimestampType, TimestamptzType)): + + def month_func(v: Any) -> int: + return pc.add( + pc.multiply(pc.years_between(pa.scalar(datetime(1970, 1, 1)), pc.local_timestamp(v)), pa.scalar(12)), + pc.add(pc.month(v), pa.scalar(-1)), + ) + + else: + raise ValueError(f"Cannot apply month transform for type: {source}") + + return lambda v: month_func(v) if v is not None else None + class DayTransform(TimeTransform[S]): """Transforms a datetime value into a day value. diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py index d84b9745a7..91856a4790 100644 --- a/tests/integration/test_writes/test_partitioned_writes.py +++ b/tests/integration/test_writes/test_partitioned_writes.py @@ -16,6 +16,8 @@ # under the License. 
# pylint:disable=redefined-outer-name +from datetime import date, datetime, timezone + import pyarrow as pa import pytest from pyspark.sql import SparkSession @@ -36,6 +38,54 @@ from utils import TABLE_SCHEMA, _create_table +@pytest.fixture(scope="session") +def arrow_table_dates() -> pa.Table: + """Pyarrow table with only null values.""" + TEST_DATES = [date(2023, 12, 31), date(2024, 1, 1), date(2024, 1, 31), date(2024, 2, 1)] + return pa.Table.from_pydict( + {"dates": TEST_DATES}, + schema=pa.schema([ + ("dates", pa.date32()), + ]), + ) + + +@pytest.fixture(scope="session") +def arrow_table_timestamp() -> pa.Table: + """Pyarrow table with only null values.""" + TEST_DATETIMES = [ + datetime(2023, 12, 31, 0, 0, 0), + datetime(2024, 1, 1, 0, 0, 0), + datetime(2024, 1, 31, 0, 0, 0), + datetime(2024, 2, 1, 0, 0, 0), + datetime(2024, 2, 1, 6, 0, 0), + ] + return pa.Table.from_pydict( + {"dates": TEST_DATETIMES}, + schema=pa.schema([ + ("timestamp", pa.timestamp(unit="us")), + ]), + ) + + +@pytest.fixture(scope="session") +def arrow_table_timestamptz() -> pa.Table: + """Pyarrow table with only null values.""" + TEST_DATETIMES_WITH_TZ = [ + datetime(2023, 12, 31, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 1, 31, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 2, 1, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 2, 1, 6, 0, 0, tzinfo=timezone.utc), + ] + return pa.Table.from_pydict( + {"dates": TEST_DATETIMES_WITH_TZ}, + schema=pa.schema([ + ("timestamptz", pa.timestamp(unit="us", tz="UTC")), + ]), + ) + + @pytest.mark.integration @pytest.mark.parametrize( "part_col", ['int', 'bool', 'string', "string_long", "long", "float", "double", "date", 'timestamp', 'timestamptz', 'binary'] @@ -384,3 +434,42 @@ def test_unsupported_transform( with pytest.raises(ValueError, match="All transforms are not supported.*"): tbl.append(arrow_table_with_null) + + +@pytest.mark.integration +@pytest.mark.parametrize( + "part_col", ['int', 'bool', 'string', "string_long", "long", "float", "double", "date", "timestamptz", "timestamp", "binary"] +) +@pytest.mark.parametrize("format_version", [1, 2]) +def test_append_time_transform_partitioned_table( + session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table, part_col: str, format_version: int +) -> None: + # Given + identifier = f"default.arrow_table_v{format_version}_appended_with_null_partitioned_on_col_{part_col}" + nested_field = TABLE_SCHEMA.find_field(part_col) + partition_spec = PartitionSpec( + PartitionField(source_id=nested_field.field_id, field_id=1001, transform=IdentityTransform(), name=part_col) + ) + + # When + tbl = _create_table( + session_catalog=session_catalog, + identifier=identifier, + properties={"format-version": str(format_version)}, + data=[], + partition_spec=partition_spec, + ) + # Append with arrow_table_1 with lines [A,B,C] and then arrow_table_2 with lines[A,B,C,A,B,C] + tbl.append(arrow_table_with_null) + tbl.append(pa.concat_tables([arrow_table_with_null, arrow_table_with_null])) + + # Then + assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}" + df = spark.table(identifier) + for col in TEST_DATA_WITH_NULL.keys(): + df = spark.table(identifier) + assert df.where(f"{col} is not null").count() == 6, f"Expected 6 non-null rows for {col}" + assert df.where(f"{col} is null").count() == 3, f"Expected 3 null rows for {col}" + # expecting 6 files: first append with [A], [B], [C], second append with [A, A], [B, B], 
[C, C] + rows = spark.sql(f"select partition from {identifier}.files").collect() + assert len(rows) == 6 From 96e55334d95f7f3d7aea6f6d8c220b2d1a7aa73d Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Sun, 5 May 2024 16:27:57 +0000 Subject: [PATCH 02/21] checkpoint2 --- pyiceberg/transforms.py | 78 ++++++++++++++----- .../test_writes/test_partitioned_writes.py | 76 +++--------------- tests/test_transforms.py | 71 ++++++++++++++++- 3 files changed, 141 insertions(+), 84 deletions(-) diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py index c75f7861c0..0cf26fe2a2 100644 --- a/pyiceberg/transforms.py +++ b/pyiceberg/transforms.py @@ -20,7 +20,7 @@ from abc import ABC, abstractmethod from enum import IntEnum from functools import singledispatch -from typing import Any, Callable, Generic, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar from typing import Literal as LiteralType from uuid import UUID @@ -82,6 +82,9 @@ from pyiceberg.utils.parsing import ParseNumberFromBrackets from pyiceberg.utils.singleton import Singleton +if TYPE_CHECKING: + import pyarrow as pa + S = TypeVar("S") T = TypeVar("T") @@ -391,6 +394,21 @@ def __repr__(self) -> str: """Return the string representation of the YearTransform class.""" return "YearTransform()" + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": + import pyarrow as pa + import pyarrow.compute as pc + + if isinstance(source, DateType): + epoch = datetime.EPOCH_DATE + elif isinstance(source, TimestampType): + epoch = datetime.EPOCH_TIMESTAMP + elif isinstance(source, TimestamptzType): + epoch = datetime.EPOCH_TIMESTAMPTZ + else: + raise ValueError(f"Cannot apply year transform for type: {source}") + + return lambda v: pc.years_between(pa.scalar(epoch), v) if v is not None else None + class MonthTransform(TimeTransform[S]): """Transforms a datetime value into a month value. 
@@ -433,29 +451,25 @@ def __repr__(self) -> str: """Return the string representation of the MonthTransform class.""" return "MonthTransform()" - def pyarrow_transform(self, source: IcebergType) -> Callable: + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": import pyarrow as pa import pyarrow.compute as pc - - if isinstance(source, DateType): - - def month_func(v: Any) -> int: - return pc.add( - pc.multiply(pc.years_between(pa.scalar(date(1970, 1, 1)), v), pa.scalar(12)), - pc.add(pc.month(v), pa.scalar(-1)), - ) - - elif isinstance(source, (TimestampType, TimestamptzType)): - - def month_func(v: Any) -> int: - return pc.add( - pc.multiply(pc.years_between(pa.scalar(datetime(1970, 1, 1)), pc.local_timestamp(v)), pa.scalar(12)), - pc.add(pc.month(v), pa.scalar(-1)), - ) + if isinstance(source, DateType): + epoch = datetime.EPOCH_DATE + elif isinstance(source, TimestampType): + epoch = datetime.EPOCH_TIMESTAMP + elif isinstance(source, TimestamptzType): + epoch = datetime.EPOCH_TIMESTAMPTZ else: raise ValueError(f"Cannot apply month transform for type: {source}") + def month_func(v: pa.Array) -> pa.Array: + return pc.add( + pc.multiply(pc.years_between(pa.scalar(epoch), v), pa.scalar(12)), + pc.add(pc.month(v), pa.scalar(-1)), + ) + return lambda v: month_func(v) if v is not None else None @@ -503,6 +517,21 @@ def __repr__(self) -> str: """Return the string representation of the DayTransform class.""" return "DayTransform()" + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": + import pyarrow as pa + import pyarrow.compute as pc + + if isinstance(source, DateType): + epoch = datetime.EPOCH_DATE + elif isinstance(source, TimestampType): + epoch = datetime.EPOCH_TIMESTAMP + elif isinstance(source, TimestamptzType): + epoch = datetime.EPOCH_TIMESTAMPTZ + else: + raise ValueError(f"Cannot apply day transform for type: {source}") + + return lambda v: pc.days_between(pa.scalar(epoch), v) if v is not None else None + class HourTransform(TimeTransform[S]): """Transforms a datetime value into a hour value. @@ -540,6 +569,19 @@ def __repr__(self) -> str: """Return the string representation of the HourTransform class.""" return "HourTransform()" + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": + import pyarrow as pa + import pyarrow.compute as pc + + if isinstance(source, TimestampType): + epoch = datetime.EPOCH_TIMESTAMP + elif isinstance(source, TimestamptzType): + epoch = datetime.EPOCH_TIMESTAMPTZ + else: + raise ValueError(f"Cannot apply month transform for type: {source}") + + return lambda v: pc.hours_between(pa.scalar(epoch), v) if v is not None else None + def _base64encode(buffer: bytes) -> str: """Convert bytes to base64 string.""" diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py index 91856a4790..07c3c43a2c 100644 --- a/tests/integration/test_writes/test_partitioned_writes.py +++ b/tests/integration/test_writes/test_partitioned_writes.py @@ -16,11 +16,11 @@ # under the License. 
# pylint:disable=redefined-outer-name -from datetime import date, datetime, timezone import pyarrow as pa import pytest from pyspark.sql import SparkSession +from typing import Any from pyiceberg.catalog import Catalog from pyiceberg.exceptions import NoSuchTableError @@ -31,6 +31,7 @@ HourTransform, IdentityTransform, MonthTransform, + Transform, TruncateTransform, YearTransform, ) @@ -38,54 +39,6 @@ from utils import TABLE_SCHEMA, _create_table -@pytest.fixture(scope="session") -def arrow_table_dates() -> pa.Table: - """Pyarrow table with only null values.""" - TEST_DATES = [date(2023, 12, 31), date(2024, 1, 1), date(2024, 1, 31), date(2024, 2, 1)] - return pa.Table.from_pydict( - {"dates": TEST_DATES}, - schema=pa.schema([ - ("dates", pa.date32()), - ]), - ) - - -@pytest.fixture(scope="session") -def arrow_table_timestamp() -> pa.Table: - """Pyarrow table with only null values.""" - TEST_DATETIMES = [ - datetime(2023, 12, 31, 0, 0, 0), - datetime(2024, 1, 1, 0, 0, 0), - datetime(2024, 1, 31, 0, 0, 0), - datetime(2024, 2, 1, 0, 0, 0), - datetime(2024, 2, 1, 6, 0, 0), - ] - return pa.Table.from_pydict( - {"dates": TEST_DATETIMES}, - schema=pa.schema([ - ("timestamp", pa.timestamp(unit="us")), - ]), - ) - - -@pytest.fixture(scope="session") -def arrow_table_timestamptz() -> pa.Table: - """Pyarrow table with only null values.""" - TEST_DATETIMES_WITH_TZ = [ - datetime(2023, 12, 31, 0, 0, 0, tzinfo=timezone.utc), - datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc), - datetime(2024, 1, 31, 0, 0, 0, tzinfo=timezone.utc), - datetime(2024, 2, 1, 0, 0, 0, tzinfo=timezone.utc), - datetime(2024, 2, 1, 6, 0, 0, tzinfo=timezone.utc), - ] - return pa.Table.from_pydict( - {"dates": TEST_DATETIMES_WITH_TZ}, - schema=pa.schema([ - ("timestamptz", pa.timestamp(unit="us", tz="UTC")), - ]), - ) - - @pytest.mark.integration @pytest.mark.parametrize( "part_col", ['int', 'bool', 'string', "string_long", "long", "float", "double", "date", 'timestamp', 'timestamptz', 'binary'] @@ -437,18 +390,19 @@ def test_unsupported_transform( @pytest.mark.integration +@pytest.mark.parametrize('transform', [YearTransform(), MonthTransform(), DayTransform()]) @pytest.mark.parametrize( - "part_col", ['int', 'bool', 'string', "string_long", "long", "float", "double", "date", "timestamptz", "timestamp", "binary"] + "part_col", ["date", "timestamp", "timestamptz"] ) @pytest.mark.parametrize("format_version", [1, 2]) -def test_append_time_transform_partitioned_table( - session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table, part_col: str, format_version: int +def test_append_ymd_transform_partitioned( + session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table, transform: Transform[Any, Any], part_col: str, format_version: int ) -> None: # Given - identifier = f"default.arrow_table_v{format_version}_appended_with_null_partitioned_on_col_{part_col}" + identifier = f"default.arrow_table_v{format_version}_with_ymd_transform_partitioned_on_col_{part_col}" nested_field = TABLE_SCHEMA.find_field(part_col) partition_spec = PartitionSpec( - PartitionField(source_id=nested_field.field_id, field_id=1001, transform=IdentityTransform(), name=part_col) + PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col) ) # When @@ -456,20 +410,14 @@ def test_append_time_transform_partitioned_table( session_catalog=session_catalog, identifier=identifier, properties={"format-version": str(format_version)}, - data=[], + data=[arrow_table_with_null], 
partition_spec=partition_spec, ) - # Append with arrow_table_1 with lines [A,B,C] and then arrow_table_2 with lines[A,B,C,A,B,C] - tbl.append(arrow_table_with_null) - tbl.append(pa.concat_tables([arrow_table_with_null, arrow_table_with_null])) # Then assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}" df = spark.table(identifier) + assert df.count() == 3, f"Expected 3 total rows for {identifier}" for col in TEST_DATA_WITH_NULL.keys(): - df = spark.table(identifier) - assert df.where(f"{col} is not null").count() == 6, f"Expected 6 non-null rows for {col}" - assert df.where(f"{col} is null").count() == 3, f"Expected 3 null rows for {col}" - # expecting 6 files: first append with [A], [B], [C], second append with [A, A], [B, B], [C, C] - rows = spark.sql(f"select partition from {identifier}.files").collect() - assert len(rows) == 6 + assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}" + assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null" \ No newline at end of file diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 4dc3d9819f..1f3c47a8d9 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -15,9 +15,9 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=eval-used,protected-access,redefined-outer-name -from datetime import date +from datetime import date, datetime, timezone from decimal import Decimal -from typing import Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional from uuid import UUID import mmh3 as mmh3 @@ -69,6 +69,7 @@ TimestampLiteral, literal, ) +from pyiceberg.partitioning import _to_partition_representation from pyiceberg.schema import Accessor from pyiceberg.transforms import ( BucketTransform, @@ -111,6 +112,9 @@ timestamptz_to_micros, ) +if TYPE_CHECKING: + import pyarrow as pa + @pytest.mark.parametrize( "test_input,test_type,expected", @@ -1808,3 +1812,66 @@ def test_strict_binary(bound_reference_binary: BoundReference[str]) -> None: _test_projection( lhs=transform.strict_project(name="name", pred=BoundIn(term=bound_reference_binary, literals=set_of_literals)), rhs=None ) + + +@pytest.fixture(scope="session") +def arrow_table_date_timestamps() -> "pa.Table": + """Pyarrow table with only date, timestamp and timestamptz values.""" + import pyarrow as pa + + return pa.Table.from_pydict( + { + "date": [date(2023, 12, 31), date(2024, 1, 1), date(2024, 1, 31), date(2024, 2, 1), date(2024, 2, 1), None], + "timestamp": [ + datetime(2023, 12, 31, 0, 0, 0), + datetime(2024, 1, 1, 0, 0, 0), + datetime(2024, 1, 31, 0, 0, 0), + datetime(2024, 2, 1, 0, 0, 0), + datetime(2024, 2, 1, 6, 0, 0), + None, + ], + "timestamptz": [ + datetime(2023, 12, 31, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 1, 31, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 2, 1, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 2, 1, 6, 0, 0, tzinfo=timezone.utc), + None, + ], + }, + schema=pa.schema([ + ("date", pa.date32()), + ("timestamp", pa.timestamp(unit="us")), + ("timestamptz", pa.timestamp(unit="us", tz="UTC")), + ]), + ) + + +@pytest.mark.parametrize('transform', [YearTransform(), MonthTransform(), DayTransform()]) +@pytest.mark.parametrize( + "source_col, source_type", [("date", DateType()), ("timestamp", TimestampType()), ("timestamptz", TimestamptzType())] +) +def test_ymd_pyarrow_transforms( + 
arrow_table_date_timestamps: "pa.Table", + source_col: str, + source_type: PrimitiveType, + transform: Transform[Any, Any], +) -> None: + assert transform.pyarrow_transform(source_type)(arrow_table_date_timestamps[source_col]).to_pylist() == [ + transform.transform(source_type)(_to_partition_representation(source_type, v)) + for v in arrow_table_date_timestamps[source_col].to_pylist() + ] + + +@pytest.mark.parametrize("source_col, source_type", [("timestamp", TimestampType()), ("timestamptz", TimestamptzType())]) +def test_hour_pyarrow_transforms(arrow_table_date_timestamps: "pa.Table", source_col: str, source_type: PrimitiveType) -> None: + assert HourTransform().pyarrow_transform(source_type)(arrow_table_date_timestamps[source_col]).to_pylist() == [ + HourTransform().transform(source_type)(_to_partition_representation(source_type, v)) + for v in arrow_table_date_timestamps[source_col].to_pylist() + ] + + +def test_hour_pyarrow_transforms_throws_with_dates(arrow_table_date_timestamps: "pa.Table") -> None: + # HourTransform is not supported for DateType + with pytest.raises(ValueError): + HourTransform().pyarrow_transform(DateType())(arrow_table_date_timestamps["date"]) From ddfa9ac2af145f2c5d7c4a3b841c220c10fc280e Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Sun, 5 May 2024 19:51:42 +0000 Subject: [PATCH 03/21] todo: sort with pyarrow_transform vals --- pyiceberg/table/__init__.py | 7 +++-- pyiceberg/transforms.py | 18 +++++++++++ .../test_writes/test_partitioned_writes.py | 31 ++++++++++++++++--- tests/test_transforms.py | 28 ++++++----------- 4 files changed, 57 insertions(+), 27 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 13186c42cc..85c3c3360b 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -381,10 +381,11 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) if not isinstance(df, pa.Table): raise ValueError(f"Expected PyArrow table, got: {df}") - supported_transforms = {IdentityTransform} - if not all(type(field.transform) in supported_transforms for field in self.table_metadata.spec().fields): + if unsupported_partitions := [ + field for field in self.table_metadata.spec().fields if not field.transform.supports_pyarrow_transform + ]: raise ValueError( - f"All transforms are not supported, expected: {supported_transforms}, but get: {[str(field) for field in self.table_metadata.spec().fields if field.transform not in supported_transforms]}." + f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}." ) _check_schema_compatible(self._table.schema(), other_schema=df.schema) diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py index 0cf26fe2a2..c8af97c301 100644 --- a/pyiceberg/transforms.py +++ b/pyiceberg/transforms.py @@ -178,6 +178,10 @@ def __eq__(self, other: Any) -> bool: return self.root == other.root return False + @property + def supports_pyarrow_transform(self) -> bool: + return False + class BucketTransform(Transform[S, int]): """Base Transform class to transform a value into a bucket partition value. @@ -352,6 +356,13 @@ def dedup_name(self) -> str: def preserves_order(self) -> bool: return True + @abstractmethod + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": ... 
+ + @property + def supports_pyarrow_transform(self) -> bool: + return True + class YearTransform(TimeTransform[S]): """Transforms a datetime value into a year value. @@ -652,6 +663,13 @@ def __repr__(self) -> str: """Return the string representation of the IdentityTransform class.""" return "IdentityTransform()" + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": + return lambda v: v + + @property + def supports_pyarrow_transform(self) -> bool: + return True + class TruncateTransform(Transform[S, S]): """A transform for truncating a value to a specified width. diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py index 07c3c43a2c..62c241b0eb 100644 --- a/tests/integration/test_writes/test_partitioned_writes.py +++ b/tests/integration/test_writes/test_partitioned_writes.py @@ -17,10 +17,11 @@ # pylint:disable=redefined-outer-name +from typing import Any + import pyarrow as pa import pytest from pyspark.sql import SparkSession -from typing import Any from pyiceberg.catalog import Catalog from pyiceberg.exceptions import NoSuchTableError @@ -390,13 +391,24 @@ def test_unsupported_transform( @pytest.mark.integration -@pytest.mark.parametrize('transform', [YearTransform(), MonthTransform(), DayTransform()]) @pytest.mark.parametrize( - "part_col", ["date", "timestamp", "timestamptz"] + "transform,expected_rows", + [ + pytest.param(YearTransform(), 2, id="year_transform"), + pytest.param(MonthTransform(), 3, id="month_transform"), + pytest.param(DayTransform(), 3, id="day_transform"), + ], ) +@pytest.mark.parametrize("part_col", ["date", "timestamp", "timestamptz"]) @pytest.mark.parametrize("format_version", [1, 2]) def test_append_ymd_transform_partitioned( - session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table, transform: Transform[Any, Any], part_col: str, format_version: int + session_catalog: Catalog, + spark: SparkSession, + arrow_table_with_null: pa.Table, + transform: Transform[Any, Any], + expected_rows: int, + part_col: str, + format_version: int, ) -> None: # Given identifier = f"default.arrow_table_v{format_version}_with_ymd_transform_partitioned_on_col_{part_col}" @@ -420,4 +432,13 @@ def test_append_ymd_transform_partitioned( assert df.count() == 3, f"Expected 3 total rows for {identifier}" for col in TEST_DATA_WITH_NULL.keys(): assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}" - assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null" \ No newline at end of file + assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null" + + assert tbl.inspect.partitions().num_rows == expected_rows + files_df = spark.sql( + f""" + SELECT * + FROM {identifier}.files + """ + ) + assert files_df.count() == expected_rows diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 1f3c47a8d9..4f926e4fb4 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1847,7 +1847,7 @@ def arrow_table_date_timestamps() -> "pa.Table": ) -@pytest.mark.parametrize('transform', [YearTransform(), MonthTransform(), DayTransform()]) +@pytest.mark.parametrize('transform', [YearTransform(), MonthTransform(), DayTransform(), HourTransform()]) @pytest.mark.parametrize( "source_col, source_type", [("date", DateType()), ("timestamp", TimestampType()), ("timestamptz", TimestamptzType())] ) @@ -1857,21 +1857,11 @@ def test_ymd_pyarrow_transforms( source_type: 
PrimitiveType, transform: Transform[Any, Any], ) -> None: - assert transform.pyarrow_transform(source_type)(arrow_table_date_timestamps[source_col]).to_pylist() == [ - transform.transform(source_type)(_to_partition_representation(source_type, v)) - for v in arrow_table_date_timestamps[source_col].to_pylist() - ] - - -@pytest.mark.parametrize("source_col, source_type", [("timestamp", TimestampType()), ("timestamptz", TimestamptzType())]) -def test_hour_pyarrow_transforms(arrow_table_date_timestamps: "pa.Table", source_col: str, source_type: PrimitiveType) -> None: - assert HourTransform().pyarrow_transform(source_type)(arrow_table_date_timestamps[source_col]).to_pylist() == [ - HourTransform().transform(source_type)(_to_partition_representation(source_type, v)) - for v in arrow_table_date_timestamps[source_col].to_pylist() - ] - - -def test_hour_pyarrow_transforms_throws_with_dates(arrow_table_date_timestamps: "pa.Table") -> None: - # HourTransform is not supported for DateType - with pytest.raises(ValueError): - HourTransform().pyarrow_transform(DateType())(arrow_table_date_timestamps["date"]) + if transform.can_transform(source_type): + assert transform.pyarrow_transform(source_type)(arrow_table_date_timestamps[source_col]).to_pylist() == [ + transform.transform(source_type)(_to_partition_representation(source_type, v)) + for v in arrow_table_date_timestamps[source_col].to_pylist() + ] + else: + with pytest.raises(ValueError): + transform.pyarrow_transform(DateType())(arrow_table_date_timestamps[source_col]) From 1a5327a5f6436d631fa21bcc9ec15d510274457e Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Mon, 6 May 2024 02:05:44 +0000 Subject: [PATCH 04/21] checkpoint --- pyiceberg/table/__init__.py | 59 +++++++++++++++---------------------- 1 file changed, 24 insertions(+), 35 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 85c3c3360b..a65fbaa5ca 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -3545,33 +3545,6 @@ class TablePartition: arrow_table_partition: pa.Table -def _get_partition_sort_order(partition_columns: list[str], reverse: bool = False) -> dict[str, Any]: - order = 'ascending' if not reverse else 'descending' - null_placement = 'at_start' if reverse else 'at_end' - return {'sort_keys': [(column_name, order) for column_name in partition_columns], 'null_placement': null_placement} - - -def group_by_partition_scheme(arrow_table: pa.Table, partition_columns: list[str]) -> pa.Table: - """Given a table, sort it by current partition scheme.""" - # only works for identity for now - sort_options = _get_partition_sort_order(partition_columns, reverse=False) - sorted_arrow_table = arrow_table.sort_by(sorting=sort_options['sort_keys'], null_placement=sort_options['null_placement']) - return sorted_arrow_table - - -def get_partition_columns( - spec: PartitionSpec, - schema: Schema, -) -> list[str]: - partition_cols = [] - for partition_field in spec.fields: - column_name = schema.find_column_name(partition_field.source_id) - if not column_name: - raise ValueError(f"{partition_field=} could not be found in {schema}.") - partition_cols.append(column_name) - return partition_cols - - def _get_table_partitions( arrow_table: pa.Table, partition_spec: PartitionSpec, @@ -3626,13 +3599,29 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T """ import pyarrow as pa - partition_columns = get_partition_columns(spec=spec, schema=schema) - arrow_table = 
group_by_partition_scheme(arrow_table, partition_columns) - - reversing_sort_order_options = _get_partition_sort_order(partition_columns, reverse=True) - reversed_indices = pa.compute.sort_indices(arrow_table, **reversing_sort_order_options).to_pylist() - - slice_instructions: list[dict[str, Any]] = [] + partition_columns: List[Tuple[PartitionField, NestedField]] = [ + (partition_field, schema.find_field(partition_field.source_id)) for partition_field in spec.fields + ] + partition_values_table = pa.table({ + str(partition.field_id): partition.pyarrow_transform(field.field_type)(arrow_table[field.name]) + for partition, field in partition_columns + }) + + # Sort by partitions + sort_indices = pa.compute.sort_indices( + partition_values_table, + sort_keys=[(col, "ascending") for col in partition_values_table.column_names], + null_placement="at_end", + ).to_pylist() + arrow_table = arrow_table.take(sort_indices) + + # Get slice_instructions to group by partitions + reversed_indices = pa.compute.sort_indices( + partition_values_table, + sort_keys=[(col, "descending") for col in partition_values_table.column_names], + null_placement="at_start", + ).to_pylist() + slice_instructions: List[Dict[str, Any]] = [] last = len(reversed_indices) reversed_indices_size = len(reversed_indices) ptr = 0 @@ -3643,6 +3632,6 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T last = reversed_indices[ptr] ptr = ptr + group_size - table_partitions: list[TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions) + table_partitions: List[TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions) return table_partitions From e067a28a286f11005710f42013f88c4e13c9dbeb Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Mon, 6 May 2024 02:09:32 +0000 Subject: [PATCH 05/21] checkpoint --- pyiceberg/table/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index a65fbaa5ca..e580eb5990 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -3603,7 +3603,7 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T (partition_field, schema.find_field(partition_field.source_id)) for partition_field in spec.fields ] partition_values_table = pa.table({ - str(partition.field_id): partition.pyarrow_transform(field.field_type)(arrow_table[field.name]) + str(partition.field_id): partition.transform.pyarrow_transform(field.field_type)(arrow_table[field.name]) for partition, field in partition_columns }) From 069f3bd9e7e1f8812ee67acc21547ca33a04f81a Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Mon, 6 May 2024 02:17:31 +0000 Subject: [PATCH 06/21] fix --- pyiceberg/table/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index e580eb5990..6937673f44 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -3616,6 +3616,7 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T arrow_table = arrow_table.take(sort_indices) # Get slice_instructions to group by partitions + partition_values_table = partition_values_table.take(sort_indices) reversed_indices = pa.compute.sort_indices( partition_values_table, sort_keys=[(col, "descending") for col in partition_values_table.column_names], From 
615d5e397b3a2fc4b99b69a2f8a0781f182b4d99 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Mon, 6 May 2024 14:42:47 +0000 Subject: [PATCH 07/21] tests --- tests/conftest.py | 43 +++++++++++ .../test_writes/test_partitioned_writes.py | 76 +++++++++++++++---- tests/test_transforms.py | 45 +++-------- 3 files changed, 115 insertions(+), 49 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 6679543694..1afdcae4bc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2145,3 +2145,46 @@ def arrow_table_with_only_nulls(pa_schema: "pa.Schema") -> "pa.Table": import pyarrow as pa return pa.Table.from_pylist([{}, {}], schema=pa_schema) + + +@pytest.fixture(scope="session") +def arrow_table_date_timestamps() -> "pa.Table": + """Pyarrow table with only date, timestamp and timestamptz values.""" + import pyarrow as pa + + return pa.Table.from_pydict( + { + "date": [date(2023, 12, 31), date(2024, 1, 1), date(2024, 1, 31), date(2024, 2, 1), date(2024, 2, 1), None], + "timestamp": [ + datetime(2023, 12, 31, 0, 0, 0), + datetime(2024, 1, 1, 0, 0, 0), + datetime(2024, 1, 31, 0, 0, 0), + datetime(2024, 2, 1, 0, 0, 0), + datetime(2024, 2, 1, 6, 0, 0), + None, + ], + "timestamptz": [ + datetime(2023, 12, 31, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 1, 31, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 2, 1, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 2, 1, 6, 0, 0, tzinfo=timezone.utc), + None, + ], + }, + schema=pa.schema([ + ("date", pa.date32()), + ("timestamp", pa.timestamp(unit="us")), + ("timestamptz", pa.timestamp(unit="us", tz="UTC")), + ]), + ) + + +@pytest.fixture(scope="session") +def arrow_table_date_timestamps_schema() -> Schema: + """Pyarrow table Schema with only date, timestamp and timestamptz values.""" + return Schema( + NestedField(field_id=1, name="date", field_type=DateType(), required=False), + NestedField(field_id=2, name="timestamp", field_type=TimestampType(), required=False), + NestedField(field_id=3, name="timestamptz", field_type=TimestamptzType(), required=False), + ) diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py index 62c241b0eb..acdb5cb7b6 100644 --- a/tests/integration/test_writes/test_partitioned_writes.py +++ b/tests/integration/test_writes/test_partitioned_writes.py @@ -26,6 +26,7 @@ from pyiceberg.catalog import Catalog from pyiceberg.exceptions import NoSuchTableError from pyiceberg.partitioning import PartitionField, PartitionSpec +from pyiceberg.schema import Schema from pyiceberg.transforms import ( BucketTransform, DayTransform, @@ -355,18 +356,6 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) -> Non (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=TruncateTransform(2), name="long_trunc"))), (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=TruncateTransform(2), name="string_trunc"))), (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=TruncateTransform(2), name="binary_trunc"))), - (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=YearTransform(), name="timestamp_year"))), - (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=YearTransform(), name="timestamptz_year"))), - (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=YearTransform(), name="date_year"))), - (PartitionSpec(PartitionField(source_id=8, field_id=1001, 
transform=MonthTransform(), name="timestamp_month"))), - (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=MonthTransform(), name="timestamptz_month"))), - (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=MonthTransform(), name="date_month"))), - (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=DayTransform(), name="timestamp_day"))), - (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=DayTransform(), name="timestamptz_day"))), - (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=DayTransform(), name="date_day"))), - (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=HourTransform(), name="timestamp_hour"))), - (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=HourTransform(), name="timestamptz_hour"))), - (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=HourTransform(), name="date_hour"))), ], ) def test_unsupported_transform( @@ -386,7 +375,10 @@ def test_unsupported_transform( properties={'format-version': '1'}, ) - with pytest.raises(ValueError, match="All transforms are not supported.*"): + with pytest.raises( + ValueError, + match="Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: *", + ): tbl.append(arrow_table_with_null) @@ -411,7 +403,7 @@ def test_append_ymd_transform_partitioned( format_version: int, ) -> None: # Given - identifier = f"default.arrow_table_v{format_version}_with_ymd_transform_partitioned_on_col_{part_col}" + identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_partition_on_col_{part_col}" nested_field = TABLE_SCHEMA.find_field(part_col) partition_spec = PartitionSpec( PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col) @@ -442,3 +434,59 @@ def test_append_ymd_transform_partitioned( """ ) assert files_df.count() == expected_rows + + +@pytest.mark.integration +@pytest.mark.parametrize( + "transform,expected_partitions", + [ + pytest.param(YearTransform(), 3, id="year_transform"), + pytest.param(MonthTransform(), 4, id="month_transform"), + pytest.param(DayTransform(), 5, id="day_transform"), + pytest.param(HourTransform(), 6, id="hour_transform"), + ], +) +@pytest.mark.parametrize("format_version", [1, 2]) +def test_append_transform_partition_verify_partitions_count( + session_catalog: Catalog, + spark: SparkSession, + arrow_table_date_timestamps: pa.Table, + arrow_table_date_timestamps_schema: Schema, + transform: Transform[Any, Any], + expected_partitions: int, + format_version: int, +) -> None: + # Given + part_col = "timestamptz" + identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_transform_partitioned_on_col_{part_col}" + nested_field = arrow_table_date_timestamps_schema.find_field(part_col) + partition_spec = PartitionSpec( + PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col) + ) + + # When + tbl = _create_table( + session_catalog=session_catalog, + identifier=identifier, + properties={"format-version": str(format_version)}, + data=[arrow_table_date_timestamps], + partition_spec=partition_spec, + schema=arrow_table_date_timestamps_schema, + ) + + # Then + assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}" + df = spark.table(identifier) + assert df.count() == 6, f"Expected 6 total rows for {identifier}" + for col in arrow_table_date_timestamps.column_names: + assert 
df.where(f"{col} is not null").count() == 5, f"Expected 2 non-null rows for {col}" + assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null" + + assert tbl.inspect.partitions().num_rows == expected_partitions + files_df = spark.sql( + f""" + SELECT * + FROM {identifier}.files + """ + ) + assert files_df.count() == expected_partitions diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 4f926e4fb4..d86817a310 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=eval-used,protected-access,redefined-outer-name -from datetime import date, datetime, timezone +from datetime import date from decimal import Decimal from typing import TYPE_CHECKING, Any, Callable, Optional from uuid import UUID @@ -1814,40 +1814,15 @@ def test_strict_binary(bound_reference_binary: BoundReference[str]) -> None: ) -@pytest.fixture(scope="session") -def arrow_table_date_timestamps() -> "pa.Table": - """Pyarrow table with only date, timestamp and timestamptz values.""" - import pyarrow as pa - - return pa.Table.from_pydict( - { - "date": [date(2023, 12, 31), date(2024, 1, 1), date(2024, 1, 31), date(2024, 2, 1), date(2024, 2, 1), None], - "timestamp": [ - datetime(2023, 12, 31, 0, 0, 0), - datetime(2024, 1, 1, 0, 0, 0), - datetime(2024, 1, 31, 0, 0, 0), - datetime(2024, 2, 1, 0, 0, 0), - datetime(2024, 2, 1, 6, 0, 0), - None, - ], - "timestamptz": [ - datetime(2023, 12, 31, 0, 0, 0, tzinfo=timezone.utc), - datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc), - datetime(2024, 1, 31, 0, 0, 0, tzinfo=timezone.utc), - datetime(2024, 2, 1, 0, 0, 0, tzinfo=timezone.utc), - datetime(2024, 2, 1, 6, 0, 0, tzinfo=timezone.utc), - None, - ], - }, - schema=pa.schema([ - ("date", pa.date32()), - ("timestamp", pa.timestamp(unit="us")), - ("timestamptz", pa.timestamp(unit="us", tz="UTC")), - ]), - ) - - -@pytest.mark.parametrize('transform', [YearTransform(), MonthTransform(), DayTransform(), HourTransform()]) +@pytest.mark.parametrize( + 'transform', + [ + pytest.param(YearTransform(), id="year_transform"), + pytest.param(MonthTransform(), id="month_transform"), + pytest.param(DayTransform(), id="day_transform"), + pytest.param(HourTransform(), id="hour_transform"), + ], +) @pytest.mark.parametrize( "source_col, source_type", [("date", DateType()), ("timestamp", TimestampType()), ("timestamptz", TimestamptzType())] ) From c0a0f321793850f598bd221405bbbda0e3337610 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Mon, 6 May 2024 16:20:19 +0000 Subject: [PATCH 08/21] more tests --- Makefile | 2 +- pyiceberg/partitioning.py | 2 +- .../test_writes/test_partitioned_writes.py | 87 +++++++++++++++++-- 3 files changed, 80 insertions(+), 11 deletions(-) diff --git a/Makefile b/Makefile index 35051be9c1..de50374cfb 100644 --- a/Makefile +++ b/Makefile @@ -43,7 +43,7 @@ test-integration: sleep 10 docker compose -f dev/docker-compose-integration.yml cp ./dev/provision.py spark-iceberg:/opt/spark/provision.py docker compose -f dev/docker-compose-integration.yml exec -T spark-iceberg ipython ./provision.py - poetry run pytest tests/ -v -m integration ${PYTEST_ARGS} + poetry run pytest tests/integration/test_writes/test_partitioned_writes.py -v -m integration ${PYTEST_ARGS} test-integration-rebuild: docker compose -f dev/docker-compose-integration.yml kill diff --git a/pyiceberg/partitioning.py b/pyiceberg/partitioning.py index 
a3cf255341..a3b482181d 100644 --- a/pyiceberg/partitioning.py +++ b/pyiceberg/partitioning.py @@ -387,7 +387,7 @@ def partition(self) -> Record: # partition key transformed with iceberg interna for raw_partition_field_value in self.raw_partition_field_values: partition_fields = self.partition_spec.source_id_to_fields_map[raw_partition_field_value.field.source_id] if len(partition_fields) != 1: - raise ValueError("partition_fields must contain exactly one field.") + raise ValueError(f"Cannot have redundant partitions: {partition_fields}") partition_field = partition_fields[0] iceberg_typed_key_values[partition_field.name] = partition_record_value( partition_field=partition_field, diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py index acdb5cb7b6..97960cd536 100644 --- a/tests/integration/test_writes/test_partitioned_writes.py +++ b/tests/integration/test_writes/test_partitioned_writes.py @@ -17,7 +17,8 @@ # pylint:disable=redefined-outer-name -from typing import Any +from datetime import date +from typing import Any, Set import pyarrow as pa import pytest @@ -440,10 +441,12 @@ def test_append_ymd_transform_partitioned( @pytest.mark.parametrize( "transform,expected_partitions", [ - pytest.param(YearTransform(), 3, id="year_transform"), - pytest.param(MonthTransform(), 4, id="month_transform"), - pytest.param(DayTransform(), 5, id="day_transform"), - pytest.param(HourTransform(), 6, id="hour_transform"), + pytest.param(YearTransform(), {53, 54, None}, id="year_transform"), + pytest.param(MonthTransform(), {647, 648, 649, None}, id="month_transform"), + pytest.param( + DayTransform(), {date(2023, 12, 31), date(2024, 1, 1), date(2024, 1, 31), date(2024, 2, 1), None}, id="day_transform" + ), + pytest.param(HourTransform(), {473328, 473352, 474072, 474096, 474102, None}, id="hour_transform"), ], ) @pytest.mark.parametrize("format_version", [1, 2]) @@ -453,7 +456,7 @@ def test_append_transform_partition_verify_partitions_count( arrow_table_date_timestamps: pa.Table, arrow_table_date_timestamps_schema: Schema, transform: Transform[Any, Any], - expected_partitions: int, + expected_partitions: Set[Any], format_version: int, ) -> None: # Given @@ -461,7 +464,7 @@ def test_append_transform_partition_verify_partitions_count( identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_transform_partitioned_on_col_{part_col}" nested_field = arrow_table_date_timestamps_schema.find_field(part_col) partition_spec = PartitionSpec( - PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col) + PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col), ) # When @@ -482,11 +485,77 @@ def test_append_transform_partition_verify_partitions_count( assert df.where(f"{col} is not null").count() == 5, f"Expected 2 non-null rows for {col}" assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null" - assert tbl.inspect.partitions().num_rows == expected_partitions + partitions_table = tbl.inspect.partitions() + assert partitions_table.num_rows == len(expected_partitions) + assert {part[part_col] for part in partitions_table['partition'].to_pylist()} == expected_partitions + files_df = spark.sql( + f""" + SELECT * + FROM {identifier}.files + """ + ) + assert files_df.count() == len(expected_partitions) + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_append_multiple_partitions( + 
session_catalog: Catalog,
+    spark: SparkSession,
+    arrow_table_date_timestamps: pa.Table,
+    arrow_table_date_timestamps_schema: Schema,
+    format_version: int,
+) -> None:
+    # Given
+    identifier = f"default.arrow_table_v{format_version}_with_multiple_partitions"
+    partition_spec = PartitionSpec(
+        PartitionField(
+            source_id=arrow_table_date_timestamps_schema.find_field("date").field_id,
+            field_id=1001,
+            transform=YearTransform(),
+            name="date_year",
+        ),
+        PartitionField(
+            source_id=arrow_table_date_timestamps_schema.find_field("timestamptz").field_id,
+            field_id=1000,
+            transform=HourTransform(),
+            name="timestamptz_hour",
+        ),
+    )
+
+    # When
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_date_timestamps],
+        partition_spec=partition_spec,
+        schema=arrow_table_date_timestamps_schema,
+    )
+
+    # Then
+    assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    assert df.count() == 6, f"Expected 6 total rows for {identifier}"
+    for col in arrow_table_date_timestamps.column_names:
+        assert df.where(f"{col} is not null").count() == 5, f"Expected 5 non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col}"
+
+    partitions_table = tbl.inspect.partitions()
+    assert partitions_table.num_rows == 6
+    partitions = partitions_table['partition'].to_pylist()
+    assert {(part["date_year"], part["timestamptz_hour"]) for part in partitions} == {
+        (53, 473328),
+        (54, 473352),
+        (54, 474072),
+        (54, 474096),
+        (54, 474102),
+        (None, None),
+    }
+    files_df = spark.sql(
+        f"""
+            SELECT *
+            FROM {identifier}.files
+        """
+    )
+    assert files_df.count() == 6
From 6a39eda3ea0992b591757f857b57c117d0b36e0b Mon Sep 17 00:00:00 2001
From: Sung Yun <107272191+syun64@users.noreply.github.com>
Date: Tue, 7 May 2024 13:24:08 +0000
Subject: [PATCH 09/21] adopt review feedback

---
 Makefile                | 2 +-
 pyiceberg/transforms.py | 18 +++++++++++++++---
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/Makefile b/Makefile
index de50374cfb..35051be9c1 100644
--- a/Makefile
+++ b/Makefile
@@ -43,7 +43,7 @@ test-integration:
 	sleep 10
 	docker compose -f dev/docker-compose-integration.yml cp ./dev/provision.py spark-iceberg:/opt/spark/provision.py
 	docker compose -f dev/docker-compose-integration.yml exec -T spark-iceberg ipython ./provision.py
-	poetry run pytest tests/integration/test_writes/test_partitioned_writes.py -v -m integration ${PYTEST_ARGS}
+	poetry run pytest tests/ -v -m integration ${PYTEST_ARGS}
 
 test-integration-rebuild:
 	docker compose -f dev/docker-compose-integration.yml kill
diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py
index c8af97c301..f4d0640d43 100644
--- a/pyiceberg/transforms.py
+++ b/pyiceberg/transforms.py
@@ -182,6 +182,9 @@ def __eq__(self, other: Any) -> bool:
             return self.root == other.root
         return False
 
     @property
     def supports_pyarrow_transform(self) -> bool:
         return False
 
+    @abstractmethod
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": ...
+
 
 class BucketTransform(Transform[S, int]):
     """Base Transform class to transform a value into a bucket partition value.
@@ -297,6 +300,9 @@ def __repr__(self) -> str: """Return the string representation of the BucketTransform class.""" return f"BucketTransform(num_buckets={self._num_buckets})" + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": + raise NotImplementedError() + class TimeResolution(IntEnum): YEAR = 6 @@ -356,9 +362,6 @@ def dedup_name(self) -> str: def preserves_order(self) -> bool: return True - @abstractmethod - def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": ... - @property def supports_pyarrow_transform(self) -> bool: return True @@ -810,6 +813,9 @@ def __repr__(self) -> str: """Return the string representation of the TruncateTransform class.""" return f"TruncateTransform(width={self._width})" + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": + raise NotImplementedError() + @singledispatch def _human_string(value: Any, _type: IcebergType) -> str: @@ -892,6 +898,9 @@ def __repr__(self) -> str: """Return the string representation of the UnknownTransform class.""" return f"UnknownTransform(transform={repr(self._transform)})" + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": + raise NotImplementedError() + class VoidTransform(Transform[S, None], Singleton): """A transform that always returns None.""" @@ -920,6 +929,9 @@ def __repr__(self) -> str: """Return the string representation of the VoidTransform class.""" return "VoidTransform()" + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": + raise NotImplementedError() + def _truncate_number( name: str, pred: BoundLiteralPredicate[L], transform: Callable[[Optional[L]], Optional[L]] From d14e137b5fabcb05de8b58de4f42591674d98e48 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Wed, 8 May 2024 22:12:57 +0000 Subject: [PATCH 10/21] comment --- pyiceberg/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py index f4d0640d43..38cc6221a2 100644 --- a/pyiceberg/transforms.py +++ b/pyiceberg/transforms.py @@ -592,7 +592,7 @@ def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Arr elif isinstance(source, TimestamptzType): epoch = datetime.EPOCH_TIMESTAMPTZ else: - raise ValueError(f"Cannot apply month transform for type: {source}") + raise ValueError(f"Cannot apply hour transform for type: {source}") return lambda v: pc.hours_between(pa.scalar(epoch), v) if v is not None else None From 5dd846d729662522409f0b06da923daefe7dfd97 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Sat, 4 May 2024 02:20:39 +0000 Subject: [PATCH 11/21] checkpoint --- pyiceberg/transforms.py | 25 ++++++ .../test_writes/test_partitioned_writes.py | 89 +++++++++++++++++++ 2 files changed, 114 insertions(+) diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py index 6dcae59e49..c75f7861c0 100644 --- a/pyiceberg/transforms.py +++ b/pyiceberg/transforms.py @@ -433,6 +433,31 @@ def __repr__(self) -> str: """Return the string representation of the MonthTransform class.""" return "MonthTransform()" + def pyarrow_transform(self, source: IcebergType) -> Callable: + import pyarrow as pa + import pyarrow.compute as pc + + if isinstance(source, DateType): + + def month_func(v: Any) -> int: + return pc.add( + pc.multiply(pc.years_between(pa.scalar(date(1970, 1, 1)), v), pa.scalar(12)), + pc.add(pc.month(v), pa.scalar(-1)), 
+ ) + + elif isinstance(source, (TimestampType, TimestamptzType)): + + def month_func(v: Any) -> int: + return pc.add( + pc.multiply(pc.years_between(pa.scalar(datetime(1970, 1, 1)), pc.local_timestamp(v)), pa.scalar(12)), + pc.add(pc.month(v), pa.scalar(-1)), + ) + + else: + raise ValueError(f"Cannot apply month transform for type: {source}") + + return lambda v: month_func(v) if v is not None else None + class DayTransform(TimeTransform[S]): """Transforms a datetime value into a day value. diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py index 5cb03e59d8..ddfb6b0f1d 100644 --- a/tests/integration/test_writes/test_partitioned_writes.py +++ b/tests/integration/test_writes/test_partitioned_writes.py @@ -16,6 +16,8 @@ # under the License. # pylint:disable=redefined-outer-name +from datetime import date, datetime, timezone + import pyarrow as pa import pytest from pyspark.sql import SparkSession @@ -36,6 +38,54 @@ from utils import TABLE_SCHEMA, _create_table +@pytest.fixture(scope="session") +def arrow_table_dates() -> pa.Table: + """Pyarrow table with only null values.""" + TEST_DATES = [date(2023, 12, 31), date(2024, 1, 1), date(2024, 1, 31), date(2024, 2, 1)] + return pa.Table.from_pydict( + {"dates": TEST_DATES}, + schema=pa.schema([ + ("dates", pa.date32()), + ]), + ) + + +@pytest.fixture(scope="session") +def arrow_table_timestamp() -> pa.Table: + """Pyarrow table with only null values.""" + TEST_DATETIMES = [ + datetime(2023, 12, 31, 0, 0, 0), + datetime(2024, 1, 1, 0, 0, 0), + datetime(2024, 1, 31, 0, 0, 0), + datetime(2024, 2, 1, 0, 0, 0), + datetime(2024, 2, 1, 6, 0, 0), + ] + return pa.Table.from_pydict( + {"dates": TEST_DATETIMES}, + schema=pa.schema([ + ("timestamp", pa.timestamp(unit="us")), + ]), + ) + + +@pytest.fixture(scope="session") +def arrow_table_timestamptz() -> pa.Table: + """Pyarrow table with only null values.""" + TEST_DATETIMES_WITH_TZ = [ + datetime(2023, 12, 31, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 1, 31, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 2, 1, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 2, 1, 6, 0, 0, tzinfo=timezone.utc), + ] + return pa.Table.from_pydict( + {"dates": TEST_DATETIMES_WITH_TZ}, + schema=pa.schema([ + ("timestamptz", pa.timestamp(unit="us", tz="UTC")), + ]), + ) + + @pytest.mark.integration @pytest.mark.parametrize( "part_col", ["int", "bool", "string", "string_long", "long", "float", "double", "date", "timestamp", "timestamptz", "binary"] @@ -384,3 +434,42 @@ def test_unsupported_transform( with pytest.raises(ValueError, match="All transforms are not supported.*"): tbl.append(arrow_table_with_null) + + +@pytest.mark.integration +@pytest.mark.parametrize( + "part_col", ['int', 'bool', 'string', "string_long", "long", "float", "double", "date", "timestamptz", "timestamp", "binary"] +) +@pytest.mark.parametrize("format_version", [1, 2]) +def test_append_time_transform_partitioned_table( + session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table, part_col: str, format_version: int +) -> None: + # Given + identifier = f"default.arrow_table_v{format_version}_appended_with_null_partitioned_on_col_{part_col}" + nested_field = TABLE_SCHEMA.find_field(part_col) + partition_spec = PartitionSpec( + PartitionField(source_id=nested_field.field_id, field_id=1001, transform=IdentityTransform(), name=part_col) + ) + + # When + tbl = _create_table( + 
session_catalog=session_catalog, + identifier=identifier, + properties={"format-version": str(format_version)}, + data=[], + partition_spec=partition_spec, + ) + # Append with arrow_table_1 with lines [A,B,C] and then arrow_table_2 with lines[A,B,C,A,B,C] + tbl.append(arrow_table_with_null) + tbl.append(pa.concat_tables([arrow_table_with_null, arrow_table_with_null])) + + # Then + assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}" + df = spark.table(identifier) + for col in TEST_DATA_WITH_NULL.keys(): + df = spark.table(identifier) + assert df.where(f"{col} is not null").count() == 6, f"Expected 6 non-null rows for {col}" + assert df.where(f"{col} is null").count() == 3, f"Expected 3 null rows for {col}" + # expecting 6 files: first append with [A], [B], [C], second append with [A, A], [B, B], [C, C] + rows = spark.sql(f"select partition from {identifier}.files").collect() + assert len(rows) == 6 From 6357193ee314e7443c4d9599856bb7a8fe3716fd Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Sun, 5 May 2024 16:27:57 +0000 Subject: [PATCH 12/21] checkpoint2 --- pyiceberg/transforms.py | 78 ++++++++++++++----- .../test_writes/test_partitioned_writes.py | 76 +++--------------- tests/test_transforms.py | 71 ++++++++++++++++- 3 files changed, 141 insertions(+), 84 deletions(-) diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py index c75f7861c0..0cf26fe2a2 100644 --- a/pyiceberg/transforms.py +++ b/pyiceberg/transforms.py @@ -20,7 +20,7 @@ from abc import ABC, abstractmethod from enum import IntEnum from functools import singledispatch -from typing import Any, Callable, Generic, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar from typing import Literal as LiteralType from uuid import UUID @@ -82,6 +82,9 @@ from pyiceberg.utils.parsing import ParseNumberFromBrackets from pyiceberg.utils.singleton import Singleton +if TYPE_CHECKING: + import pyarrow as pa + S = TypeVar("S") T = TypeVar("T") @@ -391,6 +394,21 @@ def __repr__(self) -> str: """Return the string representation of the YearTransform class.""" return "YearTransform()" + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": + import pyarrow as pa + import pyarrow.compute as pc + + if isinstance(source, DateType): + epoch = datetime.EPOCH_DATE + elif isinstance(source, TimestampType): + epoch = datetime.EPOCH_TIMESTAMP + elif isinstance(source, TimestamptzType): + epoch = datetime.EPOCH_TIMESTAMPTZ + else: + raise ValueError(f"Cannot apply year transform for type: {source}") + + return lambda v: pc.years_between(pa.scalar(epoch), v) if v is not None else None + class MonthTransform(TimeTransform[S]): """Transforms a datetime value into a month value. 
@@ -433,29 +451,25 @@ def __repr__(self) -> str: """Return the string representation of the MonthTransform class.""" return "MonthTransform()" - def pyarrow_transform(self, source: IcebergType) -> Callable: + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": import pyarrow as pa import pyarrow.compute as pc - - if isinstance(source, DateType): - - def month_func(v: Any) -> int: - return pc.add( - pc.multiply(pc.years_between(pa.scalar(date(1970, 1, 1)), v), pa.scalar(12)), - pc.add(pc.month(v), pa.scalar(-1)), - ) - - elif isinstance(source, (TimestampType, TimestamptzType)): - - def month_func(v: Any) -> int: - return pc.add( - pc.multiply(pc.years_between(pa.scalar(datetime(1970, 1, 1)), pc.local_timestamp(v)), pa.scalar(12)), - pc.add(pc.month(v), pa.scalar(-1)), - ) + if isinstance(source, DateType): + epoch = datetime.EPOCH_DATE + elif isinstance(source, TimestampType): + epoch = datetime.EPOCH_TIMESTAMP + elif isinstance(source, TimestamptzType): + epoch = datetime.EPOCH_TIMESTAMPTZ else: raise ValueError(f"Cannot apply month transform for type: {source}") + def month_func(v: pa.Array) -> pa.Array: + return pc.add( + pc.multiply(pc.years_between(pa.scalar(epoch), v), pa.scalar(12)), + pc.add(pc.month(v), pa.scalar(-1)), + ) + return lambda v: month_func(v) if v is not None else None @@ -503,6 +517,21 @@ def __repr__(self) -> str: """Return the string representation of the DayTransform class.""" return "DayTransform()" + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": + import pyarrow as pa + import pyarrow.compute as pc + + if isinstance(source, DateType): + epoch = datetime.EPOCH_DATE + elif isinstance(source, TimestampType): + epoch = datetime.EPOCH_TIMESTAMP + elif isinstance(source, TimestamptzType): + epoch = datetime.EPOCH_TIMESTAMPTZ + else: + raise ValueError(f"Cannot apply day transform for type: {source}") + + return lambda v: pc.days_between(pa.scalar(epoch), v) if v is not None else None + class HourTransform(TimeTransform[S]): """Transforms a datetime value into a hour value. @@ -540,6 +569,19 @@ def __repr__(self) -> str: """Return the string representation of the HourTransform class.""" return "HourTransform()" + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": + import pyarrow as pa + import pyarrow.compute as pc + + if isinstance(source, TimestampType): + epoch = datetime.EPOCH_TIMESTAMP + elif isinstance(source, TimestamptzType): + epoch = datetime.EPOCH_TIMESTAMPTZ + else: + raise ValueError(f"Cannot apply month transform for type: {source}") + + return lambda v: pc.hours_between(pa.scalar(epoch), v) if v is not None else None + def _base64encode(buffer: bytes) -> str: """Convert bytes to base64 string.""" diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py index ddfb6b0f1d..f8335274ab 100644 --- a/tests/integration/test_writes/test_partitioned_writes.py +++ b/tests/integration/test_writes/test_partitioned_writes.py @@ -16,11 +16,11 @@ # under the License. 
# pylint:disable=redefined-outer-name -from datetime import date, datetime, timezone import pyarrow as pa import pytest from pyspark.sql import SparkSession +from typing import Any from pyiceberg.catalog import Catalog from pyiceberg.exceptions import NoSuchTableError @@ -31,6 +31,7 @@ HourTransform, IdentityTransform, MonthTransform, + Transform, TruncateTransform, YearTransform, ) @@ -38,54 +39,6 @@ from utils import TABLE_SCHEMA, _create_table -@pytest.fixture(scope="session") -def arrow_table_dates() -> pa.Table: - """Pyarrow table with only null values.""" - TEST_DATES = [date(2023, 12, 31), date(2024, 1, 1), date(2024, 1, 31), date(2024, 2, 1)] - return pa.Table.from_pydict( - {"dates": TEST_DATES}, - schema=pa.schema([ - ("dates", pa.date32()), - ]), - ) - - -@pytest.fixture(scope="session") -def arrow_table_timestamp() -> pa.Table: - """Pyarrow table with only null values.""" - TEST_DATETIMES = [ - datetime(2023, 12, 31, 0, 0, 0), - datetime(2024, 1, 1, 0, 0, 0), - datetime(2024, 1, 31, 0, 0, 0), - datetime(2024, 2, 1, 0, 0, 0), - datetime(2024, 2, 1, 6, 0, 0), - ] - return pa.Table.from_pydict( - {"dates": TEST_DATETIMES}, - schema=pa.schema([ - ("timestamp", pa.timestamp(unit="us")), - ]), - ) - - -@pytest.fixture(scope="session") -def arrow_table_timestamptz() -> pa.Table: - """Pyarrow table with only null values.""" - TEST_DATETIMES_WITH_TZ = [ - datetime(2023, 12, 31, 0, 0, 0, tzinfo=timezone.utc), - datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc), - datetime(2024, 1, 31, 0, 0, 0, tzinfo=timezone.utc), - datetime(2024, 2, 1, 0, 0, 0, tzinfo=timezone.utc), - datetime(2024, 2, 1, 6, 0, 0, tzinfo=timezone.utc), - ] - return pa.Table.from_pydict( - {"dates": TEST_DATETIMES_WITH_TZ}, - schema=pa.schema([ - ("timestamptz", pa.timestamp(unit="us", tz="UTC")), - ]), - ) - - @pytest.mark.integration @pytest.mark.parametrize( "part_col", ["int", "bool", "string", "string_long", "long", "float", "double", "date", "timestamp", "timestamptz", "binary"] @@ -437,18 +390,19 @@ def test_unsupported_transform( @pytest.mark.integration +@pytest.mark.parametrize('transform', [YearTransform(), MonthTransform(), DayTransform()]) @pytest.mark.parametrize( - "part_col", ['int', 'bool', 'string', "string_long", "long", "float", "double", "date", "timestamptz", "timestamp", "binary"] + "part_col", ["date", "timestamp", "timestamptz"] ) @pytest.mark.parametrize("format_version", [1, 2]) -def test_append_time_transform_partitioned_table( - session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table, part_col: str, format_version: int +def test_append_ymd_transform_partitioned( + session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table, transform: Transform[Any, Any], part_col: str, format_version: int ) -> None: # Given - identifier = f"default.arrow_table_v{format_version}_appended_with_null_partitioned_on_col_{part_col}" + identifier = f"default.arrow_table_v{format_version}_with_ymd_transform_partitioned_on_col_{part_col}" nested_field = TABLE_SCHEMA.find_field(part_col) partition_spec = PartitionSpec( - PartitionField(source_id=nested_field.field_id, field_id=1001, transform=IdentityTransform(), name=part_col) + PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col) ) # When @@ -456,20 +410,14 @@ def test_append_time_transform_partitioned_table( session_catalog=session_catalog, identifier=identifier, properties={"format-version": str(format_version)}, - data=[], + data=[arrow_table_with_null], 
partition_spec=partition_spec, ) - # Append with arrow_table_1 with lines [A,B,C] and then arrow_table_2 with lines[A,B,C,A,B,C] - tbl.append(arrow_table_with_null) - tbl.append(pa.concat_tables([arrow_table_with_null, arrow_table_with_null])) # Then assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}" df = spark.table(identifier) + assert df.count() == 3, f"Expected 3 total rows for {identifier}" for col in TEST_DATA_WITH_NULL.keys(): - df = spark.table(identifier) - assert df.where(f"{col} is not null").count() == 6, f"Expected 6 non-null rows for {col}" - assert df.where(f"{col} is null").count() == 3, f"Expected 3 null rows for {col}" - # expecting 6 files: first append with [A], [B], [C], second append with [A, A], [B, B], [C, C] - rows = spark.sql(f"select partition from {identifier}.files").collect() - assert len(rows) == 6 + assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}" + assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null" \ No newline at end of file diff --git a/tests/test_transforms.py b/tests/test_transforms.py index b8bef4b998..4a1e066b1e 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -15,9 +15,9 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=eval-used,protected-access,redefined-outer-name -from datetime import date +from datetime import date, datetime, timezone from decimal import Decimal -from typing import Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional from uuid import UUID import mmh3 as mmh3 @@ -69,6 +69,7 @@ TimestampLiteral, literal, ) +from pyiceberg.partitioning import _to_partition_representation from pyiceberg.schema import Accessor from pyiceberg.transforms import ( BucketTransform, @@ -111,6 +112,9 @@ timestamptz_to_micros, ) +if TYPE_CHECKING: + import pyarrow as pa + @pytest.mark.parametrize( "test_input,test_type,expected", @@ -1808,3 +1812,66 @@ def test_strict_binary(bound_reference_binary: BoundReference[str]) -> None: _test_projection( lhs=transform.strict_project(name="name", pred=BoundIn(term=bound_reference_binary, literals=set_of_literals)), rhs=None ) + + +@pytest.fixture(scope="session") +def arrow_table_date_timestamps() -> "pa.Table": + """Pyarrow table with only date, timestamp and timestamptz values.""" + import pyarrow as pa + + return pa.Table.from_pydict( + { + "date": [date(2023, 12, 31), date(2024, 1, 1), date(2024, 1, 31), date(2024, 2, 1), date(2024, 2, 1), None], + "timestamp": [ + datetime(2023, 12, 31, 0, 0, 0), + datetime(2024, 1, 1, 0, 0, 0), + datetime(2024, 1, 31, 0, 0, 0), + datetime(2024, 2, 1, 0, 0, 0), + datetime(2024, 2, 1, 6, 0, 0), + None, + ], + "timestamptz": [ + datetime(2023, 12, 31, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 1, 31, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 2, 1, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 2, 1, 6, 0, 0, tzinfo=timezone.utc), + None, + ], + }, + schema=pa.schema([ + ("date", pa.date32()), + ("timestamp", pa.timestamp(unit="us")), + ("timestamptz", pa.timestamp(unit="us", tz="UTC")), + ]), + ) + + +@pytest.mark.parametrize('transform', [YearTransform(), MonthTransform(), DayTransform()]) +@pytest.mark.parametrize( + "source_col, source_type", [("date", DateType()), ("timestamp", TimestampType()), ("timestamptz", TimestamptzType())] +) +def test_ymd_pyarrow_transforms( + 
arrow_table_date_timestamps: "pa.Table", + source_col: str, + source_type: PrimitiveType, + transform: Transform[Any, Any], +) -> None: + assert transform.pyarrow_transform(source_type)(arrow_table_date_timestamps[source_col]).to_pylist() == [ + transform.transform(source_type)(_to_partition_representation(source_type, v)) + for v in arrow_table_date_timestamps[source_col].to_pylist() + ] + + +@pytest.mark.parametrize("source_col, source_type", [("timestamp", TimestampType()), ("timestamptz", TimestamptzType())]) +def test_hour_pyarrow_transforms(arrow_table_date_timestamps: "pa.Table", source_col: str, source_type: PrimitiveType) -> None: + assert HourTransform().pyarrow_transform(source_type)(arrow_table_date_timestamps[source_col]).to_pylist() == [ + HourTransform().transform(source_type)(_to_partition_representation(source_type, v)) + for v in arrow_table_date_timestamps[source_col].to_pylist() + ] + + +def test_hour_pyarrow_transforms_throws_with_dates(arrow_table_date_timestamps: "pa.Table") -> None: + # HourTransform is not supported for DateType + with pytest.raises(ValueError): + HourTransform().pyarrow_transform(DateType())(arrow_table_date_timestamps["date"]) From c30a57cfe93aaf979df949f801929ecf10079601 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Sun, 5 May 2024 19:51:42 +0000 Subject: [PATCH 13/21] todo: sort with pyarrow_transform vals --- pyiceberg/table/__init__.py | 7 +++-- pyiceberg/transforms.py | 18 +++++++++++ .../test_writes/test_partitioned_writes.py | 31 ++++++++++++++++--- tests/test_transforms.py | 28 ++++++----------- 4 files changed, 57 insertions(+), 27 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index aa108de08b..ea88312368 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -392,10 +392,11 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) if not isinstance(df, pa.Table): raise ValueError(f"Expected PyArrow table, got: {df}") - supported_transforms = {IdentityTransform} - if not all(type(field.transform) in supported_transforms for field in self.table_metadata.spec().fields): + if unsupported_partitions := [ + field for field in self.table_metadata.spec().fields if not field.transform.supports_pyarrow_transform + ]: raise ValueError( - f"All transforms are not supported, expected: {supported_transforms}, but get: {[str(field) for field in self.table_metadata.spec().fields if field.transform not in supported_transforms]}." + f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}." ) _check_schema_compatible(self._table.schema(), other_schema=df.schema) diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py index 0cf26fe2a2..c8af97c301 100644 --- a/pyiceberg/transforms.py +++ b/pyiceberg/transforms.py @@ -178,6 +178,10 @@ def __eq__(self, other: Any) -> bool: return self.root == other.root return False + @property + def supports_pyarrow_transform(self) -> bool: + return False + class BucketTransform(Transform[S, int]): """Base Transform class to transform a value into a bucket partition value. @@ -352,6 +356,13 @@ def dedup_name(self) -> str: def preserves_order(self) -> bool: return True + @abstractmethod + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": ... 
+ + @property + def supports_pyarrow_transform(self) -> bool: + return True + class YearTransform(TimeTransform[S]): """Transforms a datetime value into a year value. @@ -652,6 +663,13 @@ def __repr__(self) -> str: """Return the string representation of the IdentityTransform class.""" return "IdentityTransform()" + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": + return lambda v: v + + @property + def supports_pyarrow_transform(self) -> bool: + return True + class TruncateTransform(Transform[S, S]): """A transform for truncating a value to a specified width. diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py index f8335274ab..3a0e38d3f2 100644 --- a/tests/integration/test_writes/test_partitioned_writes.py +++ b/tests/integration/test_writes/test_partitioned_writes.py @@ -17,10 +17,11 @@ # pylint:disable=redefined-outer-name +from typing import Any + import pyarrow as pa import pytest from pyspark.sql import SparkSession -from typing import Any from pyiceberg.catalog import Catalog from pyiceberg.exceptions import NoSuchTableError @@ -390,13 +391,24 @@ def test_unsupported_transform( @pytest.mark.integration -@pytest.mark.parametrize('transform', [YearTransform(), MonthTransform(), DayTransform()]) @pytest.mark.parametrize( - "part_col", ["date", "timestamp", "timestamptz"] + "transform,expected_rows", + [ + pytest.param(YearTransform(), 2, id="year_transform"), + pytest.param(MonthTransform(), 3, id="month_transform"), + pytest.param(DayTransform(), 3, id="day_transform"), + ], ) +@pytest.mark.parametrize("part_col", ["date", "timestamp", "timestamptz"]) @pytest.mark.parametrize("format_version", [1, 2]) def test_append_ymd_transform_partitioned( - session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table, transform: Transform[Any, Any], part_col: str, format_version: int + session_catalog: Catalog, + spark: SparkSession, + arrow_table_with_null: pa.Table, + transform: Transform[Any, Any], + expected_rows: int, + part_col: str, + format_version: int, ) -> None: # Given identifier = f"default.arrow_table_v{format_version}_with_ymd_transform_partitioned_on_col_{part_col}" @@ -420,4 +432,13 @@ def test_append_ymd_transform_partitioned( assert df.count() == 3, f"Expected 3 total rows for {identifier}" for col in TEST_DATA_WITH_NULL.keys(): assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}" - assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null" \ No newline at end of file + assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null" + + assert tbl.inspect.partitions().num_rows == expected_rows + files_df = spark.sql( + f""" + SELECT * + FROM {identifier}.files + """ + ) + assert files_df.count() == expected_rows diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 4a1e066b1e..3f1591c01c 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1847,7 +1847,7 @@ def arrow_table_date_timestamps() -> "pa.Table": ) -@pytest.mark.parametrize('transform', [YearTransform(), MonthTransform(), DayTransform()]) +@pytest.mark.parametrize('transform', [YearTransform(), MonthTransform(), DayTransform(), HourTransform()]) @pytest.mark.parametrize( "source_col, source_type", [("date", DateType()), ("timestamp", TimestampType()), ("timestamptz", TimestamptzType())] ) @@ -1857,21 +1857,11 @@ def test_ymd_pyarrow_transforms( source_type: 
PrimitiveType, transform: Transform[Any, Any], ) -> None: - assert transform.pyarrow_transform(source_type)(arrow_table_date_timestamps[source_col]).to_pylist() == [ - transform.transform(source_type)(_to_partition_representation(source_type, v)) - for v in arrow_table_date_timestamps[source_col].to_pylist() - ] - - -@pytest.mark.parametrize("source_col, source_type", [("timestamp", TimestampType()), ("timestamptz", TimestamptzType())]) -def test_hour_pyarrow_transforms(arrow_table_date_timestamps: "pa.Table", source_col: str, source_type: PrimitiveType) -> None: - assert HourTransform().pyarrow_transform(source_type)(arrow_table_date_timestamps[source_col]).to_pylist() == [ - HourTransform().transform(source_type)(_to_partition_representation(source_type, v)) - for v in arrow_table_date_timestamps[source_col].to_pylist() - ] - - -def test_hour_pyarrow_transforms_throws_with_dates(arrow_table_date_timestamps: "pa.Table") -> None: - # HourTransform is not supported for DateType - with pytest.raises(ValueError): - HourTransform().pyarrow_transform(DateType())(arrow_table_date_timestamps["date"]) + if transform.can_transform(source_type): + assert transform.pyarrow_transform(source_type)(arrow_table_date_timestamps[source_col]).to_pylist() == [ + transform.transform(source_type)(_to_partition_representation(source_type, v)) + for v in arrow_table_date_timestamps[source_col].to_pylist() + ] + else: + with pytest.raises(ValueError): + transform.pyarrow_transform(DateType())(arrow_table_date_timestamps[source_col]) From 541655f16940998420688529c639e8481d178c93 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Mon, 6 May 2024 02:05:44 +0000 Subject: [PATCH 14/21] checkpoint --- pyiceberg/table/__init__.py | 59 +++++++++++++++---------------------- 1 file changed, 24 insertions(+), 35 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index ea88312368..4040f9a616 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -3644,33 +3644,6 @@ class TablePartition: arrow_table_partition: pa.Table -def _get_partition_sort_order(partition_columns: list[str], reverse: bool = False) -> dict[str, Any]: - order = "ascending" if not reverse else "descending" - null_placement = "at_start" if reverse else "at_end" - return {"sort_keys": [(column_name, order) for column_name in partition_columns], "null_placement": null_placement} - - -def group_by_partition_scheme(arrow_table: pa.Table, partition_columns: list[str]) -> pa.Table: - """Given a table, sort it by current partition scheme.""" - # only works for identity for now - sort_options = _get_partition_sort_order(partition_columns, reverse=False) - sorted_arrow_table = arrow_table.sort_by(sorting=sort_options["sort_keys"], null_placement=sort_options["null_placement"]) - return sorted_arrow_table - - -def get_partition_columns( - spec: PartitionSpec, - schema: Schema, -) -> list[str]: - partition_cols = [] - for partition_field in spec.fields: - column_name = schema.find_column_name(partition_field.source_id) - if not column_name: - raise ValueError(f"{partition_field=} could not be found in {schema}.") - partition_cols.append(column_name) - return partition_cols - - def _get_table_partitions( arrow_table: pa.Table, partition_spec: PartitionSpec, @@ -3725,13 +3698,29 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T """ import pyarrow as pa - partition_columns = get_partition_columns(spec=spec, schema=schema) - arrow_table = 
group_by_partition_scheme(arrow_table, partition_columns) - - reversing_sort_order_options = _get_partition_sort_order(partition_columns, reverse=True) - reversed_indices = pa.compute.sort_indices(arrow_table, **reversing_sort_order_options).to_pylist() - - slice_instructions: list[dict[str, Any]] = [] + partition_columns: List[Tuple[PartitionField, NestedField]] = [ + (partition_field, schema.find_field(partition_field.source_id)) for partition_field in spec.fields + ] + partition_values_table = pa.table({ + str(partition.field_id): partition.pyarrow_transform(field.field_type)(arrow_table[field.name]) + for partition, field in partition_columns + }) + + # Sort by partitions + sort_indices = pa.compute.sort_indices( + partition_values_table, + sort_keys=[(col, "ascending") for col in partition_values_table.column_names], + null_placement="at_end", + ).to_pylist() + arrow_table = arrow_table.take(sort_indices) + + # Get slice_instructions to group by partitions + reversed_indices = pa.compute.sort_indices( + partition_values_table, + sort_keys=[(col, "descending") for col in partition_values_table.column_names], + null_placement="at_start", + ).to_pylist() + slice_instructions: List[Dict[str, Any]] = [] last = len(reversed_indices) reversed_indices_size = len(reversed_indices) ptr = 0 @@ -3742,6 +3731,6 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T last = reversed_indices[ptr] ptr = ptr + group_size - table_partitions: list[TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions) + table_partitions: List[TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions) return table_partitions From afe83b177a74039218a18cd0e49a80ff0513de1c Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Mon, 6 May 2024 02:09:32 +0000 Subject: [PATCH 15/21] checkpoint --- pyiceberg/table/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 4040f9a616..16482108a6 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -3702,7 +3702,7 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T (partition_field, schema.find_field(partition_field.source_id)) for partition_field in spec.fields ] partition_values_table = pa.table({ - str(partition.field_id): partition.pyarrow_transform(field.field_type)(arrow_table[field.name]) + str(partition.field_id): partition.transform.pyarrow_transform(field.field_type)(arrow_table[field.name]) for partition, field in partition_columns }) From 00ca5f04b2281a82ae3fc869d5b8a2b3cdd0e2b2 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Mon, 6 May 2024 02:17:31 +0000 Subject: [PATCH 16/21] fix --- pyiceberg/table/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 16482108a6..f160ab2441 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -3715,6 +3715,7 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T arrow_table = arrow_table.take(sort_indices) # Get slice_instructions to group by partitions + partition_values_table = partition_values_table.take(sort_indices) reversed_indices = pa.compute.sort_indices( partition_values_table, sort_keys=[(col, "descending") for col in partition_values_table.column_names], From 
511e98824aafc9c79d8c27ddefc3431797019ebf Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Mon, 6 May 2024 14:42:47 +0000 Subject: [PATCH 17/21] tests --- tests/conftest.py | 43 +++++++++++ .../test_writes/test_partitioned_writes.py | 76 +++++++++++++++---- tests/test_transforms.py | 45 +++-------- 3 files changed, 115 insertions(+), 49 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 01915b7d82..d3f23689a2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2158,3 +2158,46 @@ def arrow_table_with_only_nulls(pa_schema: "pa.Schema") -> "pa.Table": import pyarrow as pa return pa.Table.from_pylist([{}, {}], schema=pa_schema) + + +@pytest.fixture(scope="session") +def arrow_table_date_timestamps() -> "pa.Table": + """Pyarrow table with only date, timestamp and timestamptz values.""" + import pyarrow as pa + + return pa.Table.from_pydict( + { + "date": [date(2023, 12, 31), date(2024, 1, 1), date(2024, 1, 31), date(2024, 2, 1), date(2024, 2, 1), None], + "timestamp": [ + datetime(2023, 12, 31, 0, 0, 0), + datetime(2024, 1, 1, 0, 0, 0), + datetime(2024, 1, 31, 0, 0, 0), + datetime(2024, 2, 1, 0, 0, 0), + datetime(2024, 2, 1, 6, 0, 0), + None, + ], + "timestamptz": [ + datetime(2023, 12, 31, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 1, 31, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 2, 1, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 2, 1, 6, 0, 0, tzinfo=timezone.utc), + None, + ], + }, + schema=pa.schema([ + ("date", pa.date32()), + ("timestamp", pa.timestamp(unit="us")), + ("timestamptz", pa.timestamp(unit="us", tz="UTC")), + ]), + ) + + +@pytest.fixture(scope="session") +def arrow_table_date_timestamps_schema() -> Schema: + """Pyarrow table Schema with only date, timestamp and timestamptz values.""" + return Schema( + NestedField(field_id=1, name="date", field_type=DateType(), required=False), + NestedField(field_id=2, name="timestamp", field_type=TimestampType(), required=False), + NestedField(field_id=3, name="timestamptz", field_type=TimestamptzType(), required=False), + ) diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py index 3a0e38d3f2..9df2ec218e 100644 --- a/tests/integration/test_writes/test_partitioned_writes.py +++ b/tests/integration/test_writes/test_partitioned_writes.py @@ -26,6 +26,7 @@ from pyiceberg.catalog import Catalog from pyiceberg.exceptions import NoSuchTableError from pyiceberg.partitioning import PartitionField, PartitionSpec +from pyiceberg.schema import Schema from pyiceberg.transforms import ( BucketTransform, DayTransform, @@ -355,18 +356,6 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) -> Non (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=TruncateTransform(2), name="long_trunc"))), (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=TruncateTransform(2), name="string_trunc"))), (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=TruncateTransform(2), name="binary_trunc"))), - (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=YearTransform(), name="timestamp_year"))), - (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=YearTransform(), name="timestamptz_year"))), - (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=YearTransform(), name="date_year"))), - (PartitionSpec(PartitionField(source_id=8, field_id=1001, 
transform=MonthTransform(), name="timestamp_month"))), - (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=MonthTransform(), name="timestamptz_month"))), - (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=MonthTransform(), name="date_month"))), - (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=DayTransform(), name="timestamp_day"))), - (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=DayTransform(), name="timestamptz_day"))), - (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=DayTransform(), name="date_day"))), - (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=HourTransform(), name="timestamp_hour"))), - (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=HourTransform(), name="timestamptz_hour"))), - (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=HourTransform(), name="date_hour"))), ], ) def test_unsupported_transform( @@ -386,7 +375,10 @@ def test_unsupported_transform( properties={"format-version": "1"}, ) - with pytest.raises(ValueError, match="All transforms are not supported.*"): + with pytest.raises( + ValueError, + match="Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: *", + ): tbl.append(arrow_table_with_null) @@ -411,7 +403,7 @@ def test_append_ymd_transform_partitioned( format_version: int, ) -> None: # Given - identifier = f"default.arrow_table_v{format_version}_with_ymd_transform_partitioned_on_col_{part_col}" + identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_partition_on_col_{part_col}" nested_field = TABLE_SCHEMA.find_field(part_col) partition_spec = PartitionSpec( PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col) @@ -442,3 +434,59 @@ def test_append_ymd_transform_partitioned( """ ) assert files_df.count() == expected_rows + + +@pytest.mark.integration +@pytest.mark.parametrize( + "transform,expected_partitions", + [ + pytest.param(YearTransform(), 3, id="year_transform"), + pytest.param(MonthTransform(), 4, id="month_transform"), + pytest.param(DayTransform(), 5, id="day_transform"), + pytest.param(HourTransform(), 6, id="hour_transform"), + ], +) +@pytest.mark.parametrize("format_version", [1, 2]) +def test_append_transform_partition_verify_partitions_count( + session_catalog: Catalog, + spark: SparkSession, + arrow_table_date_timestamps: pa.Table, + arrow_table_date_timestamps_schema: Schema, + transform: Transform[Any, Any], + expected_partitions: int, + format_version: int, +) -> None: + # Given + part_col = "timestamptz" + identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_transform_partitioned_on_col_{part_col}" + nested_field = arrow_table_date_timestamps_schema.find_field(part_col) + partition_spec = PartitionSpec( + PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col) + ) + + # When + tbl = _create_table( + session_catalog=session_catalog, + identifier=identifier, + properties={"format-version": str(format_version)}, + data=[arrow_table_date_timestamps], + partition_spec=partition_spec, + schema=arrow_table_date_timestamps_schema, + ) + + # Then + assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}" + df = spark.table(identifier) + assert df.count() == 6, f"Expected 6 total rows for {identifier}" + for col in arrow_table_date_timestamps.column_names: + assert 
df.where(f"{col} is not null").count() == 5, f"Expected 5 non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"
+
+    assert tbl.inspect.partitions().num_rows == expected_partitions
+    files_df = spark.sql(
+        f"""
+            SELECT *
+            FROM {identifier}.files
+        """
+    )
+    assert files_df.count() == expected_partitions
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index 3f1591c01c..15ef7d0ea2 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=eval-used,protected-access,redefined-outer-name
-from datetime import date, datetime, timezone
+from datetime import date
 from decimal import Decimal
 from typing import TYPE_CHECKING, Any, Callable, Optional
 from uuid import UUID
@@ -1814,40 +1814,15 @@ def test_strict_binary(bound_reference_binary: BoundReference[str]) -> None:
     )
 
 
-@pytest.fixture(scope="session")
-def arrow_table_date_timestamps() -> "pa.Table":
-    """Pyarrow table with only date, timestamp and timestamptz values."""
-    import pyarrow as pa
-
-    return pa.Table.from_pydict(
-        {
-            "date": [date(2023, 12, 31), date(2024, 1, 1), date(2024, 1, 31), date(2024, 2, 1), date(2024, 2, 1), None],
-            "timestamp": [
-                datetime(2023, 12, 31, 0, 0, 0),
-                datetime(2024, 1, 1, 0, 0, 0),
-                datetime(2024, 1, 31, 0, 0, 0),
-                datetime(2024, 2, 1, 0, 0, 0),
-                datetime(2024, 2, 1, 6, 0, 0),
-                None,
-            ],
-            "timestamptz": [
-                datetime(2023, 12, 31, 0, 0, 0, tzinfo=timezone.utc),
-                datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
-                datetime(2024, 1, 31, 0, 0, 0, tzinfo=timezone.utc),
-                datetime(2024, 2, 1, 0, 0, 0, tzinfo=timezone.utc),
-                datetime(2024, 2, 1, 6, 0, 0, tzinfo=timezone.utc),
-                None,
-            ],
-        },
-        schema=pa.schema([
-            ("date", pa.date32()),
-            ("timestamp", pa.timestamp(unit="us")),
-            ("timestamptz", pa.timestamp(unit="us", tz="UTC")),
-        ]),
-    )
-
-
-@pytest.mark.parametrize('transform', [YearTransform(), MonthTransform(), DayTransform(), HourTransform()])
+@pytest.mark.parametrize(
+    'transform',
+    [
+        pytest.param(YearTransform(), id="year_transform"),
+        pytest.param(MonthTransform(), id="month_transform"),
+        pytest.param(DayTransform(), id="day_transform"),
+        pytest.param(HourTransform(), id="hour_transform"),
+    ],
+)
 @pytest.mark.parametrize(
     "source_col, source_type", [("date", DateType()), ("timestamp", TimestampType()), ("timestamptz", TimestamptzType())]
 )

From 3b784abf2aeba1bad07581bbdd1bf5eba6efc5c3 Mon Sep 17 00:00:00 2001
From: Sung Yun <107272191+syun64@users.noreply.github.com>
Date: Mon, 6 May 2024 16:20:19 +0000
Subject: [PATCH 18/21] more tests

---
 Makefile                                      |  2 +-
 pyiceberg/partitioning.py                     |  2 +-
 .../test_writes/test_partitioned_writes.py    | 87 +++++++++++++++++--
 3 files changed, 80 insertions(+), 11 deletions(-)

diff --git a/Makefile b/Makefile
index 35051be9c1..de50374cfb 100644
--- a/Makefile
+++ b/Makefile
@@ -43,7 +43,7 @@ test-integration:
 	sleep 10
 	docker compose -f dev/docker-compose-integration.yml cp ./dev/provision.py spark-iceberg:/opt/spark/provision.py
 	docker compose -f dev/docker-compose-integration.yml exec -T spark-iceberg ipython ./provision.py
-	poetry run pytest tests/ -v -m integration ${PYTEST_ARGS}
+	poetry run pytest tests/integration/test_writes/test_partitioned_writes.py -v -m integration ${PYTEST_ARGS}
 
 test-integration-rebuild:
 	docker compose -f dev/docker-compose-integration.yml kill
diff --git a/pyiceberg/partitioning.py b/pyiceberg/partitioning.py
index
481207db7a..da52d5df8e 100644 --- a/pyiceberg/partitioning.py +++ b/pyiceberg/partitioning.py @@ -387,7 +387,7 @@ def partition(self) -> Record: # partition key transformed with iceberg interna for raw_partition_field_value in self.raw_partition_field_values: partition_fields = self.partition_spec.source_id_to_fields_map[raw_partition_field_value.field.source_id] if len(partition_fields) != 1: - raise ValueError("partition_fields must contain exactly one field.") + raise ValueError(f"Cannot have redundant partitions: {partition_fields}") partition_field = partition_fields[0] iceberg_typed_key_values[partition_field.name] = partition_record_value( partition_field=partition_field, diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py index 9df2ec218e..2f2aabc1fc 100644 --- a/tests/integration/test_writes/test_partitioned_writes.py +++ b/tests/integration/test_writes/test_partitioned_writes.py @@ -17,7 +17,8 @@ # pylint:disable=redefined-outer-name -from typing import Any +from datetime import date +from typing import Any, Set import pyarrow as pa import pytest @@ -440,10 +441,12 @@ def test_append_ymd_transform_partitioned( @pytest.mark.parametrize( "transform,expected_partitions", [ - pytest.param(YearTransform(), 3, id="year_transform"), - pytest.param(MonthTransform(), 4, id="month_transform"), - pytest.param(DayTransform(), 5, id="day_transform"), - pytest.param(HourTransform(), 6, id="hour_transform"), + pytest.param(YearTransform(), {53, 54, None}, id="year_transform"), + pytest.param(MonthTransform(), {647, 648, 649, None}, id="month_transform"), + pytest.param( + DayTransform(), {date(2023, 12, 31), date(2024, 1, 1), date(2024, 1, 31), date(2024, 2, 1), None}, id="day_transform" + ), + pytest.param(HourTransform(), {473328, 473352, 474072, 474096, 474102, None}, id="hour_transform"), ], ) @pytest.mark.parametrize("format_version", [1, 2]) @@ -453,7 +456,7 @@ def test_append_transform_partition_verify_partitions_count( arrow_table_date_timestamps: pa.Table, arrow_table_date_timestamps_schema: Schema, transform: Transform[Any, Any], - expected_partitions: int, + expected_partitions: Set[Any], format_version: int, ) -> None: # Given @@ -461,7 +464,7 @@ def test_append_transform_partition_verify_partitions_count( identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_transform_partitioned_on_col_{part_col}" nested_field = arrow_table_date_timestamps_schema.find_field(part_col) partition_spec = PartitionSpec( - PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col) + PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col), ) # When @@ -482,11 +485,77 @@ def test_append_transform_partition_verify_partitions_count( assert df.where(f"{col} is not null").count() == 5, f"Expected 2 non-null rows for {col}" assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null" - assert tbl.inspect.partitions().num_rows == expected_partitions + partitions_table = tbl.inspect.partitions() + assert partitions_table.num_rows == len(expected_partitions) + assert {part[part_col] for part in partitions_table['partition'].to_pylist()} == expected_partitions + files_df = spark.sql( + f""" + SELECT * + FROM {identifier}.files + """ + ) + assert files_df.count() == len(expected_partitions) + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_append_multiple_partitions( + 
session_catalog: Catalog,
+    spark: SparkSession,
+    arrow_table_date_timestamps: pa.Table,
+    arrow_table_date_timestamps_schema: Schema,
+    format_version: int,
+) -> None:
+    # Given
+    identifier = f"default.arrow_table_v{format_version}_with_multiple_partitions"
+    partition_spec = PartitionSpec(
+        PartitionField(
+            source_id=arrow_table_date_timestamps_schema.find_field("date").field_id,
+            field_id=1001,
+            transform=YearTransform(),
+            name="date_year",
+        ),
+        PartitionField(
+            source_id=arrow_table_date_timestamps_schema.find_field("timestamptz").field_id,
+            field_id=1000,
+            transform=HourTransform(),
+            name="timestamptz_hour",
+        ),
+    )
+
+    # When
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_date_timestamps],
+        partition_spec=partition_spec,
+        schema=arrow_table_date_timestamps_schema,
+    )
+
+    # Then
+    assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    assert df.count() == 6, f"Expected 6 total rows for {identifier}"
+    for col in arrow_table_date_timestamps.column_names:
+        assert df.where(f"{col} is not null").count() == 5, f"Expected 5 non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"
+
+    partitions_table = tbl.inspect.partitions()
+    assert partitions_table.num_rows == 6
+    partitions = partitions_table['partition'].to_pylist()
+    assert {(part["date_year"], part["timestamptz_hour"]) for part in partitions} == {
+        (53, 473328),
+        (54, 473352),
+        (54, 474072),
+        (54, 474096),
+        (54, 474102),
+        (None, None),
+    }
     files_df = spark.sql(
         f"""
             SELECT *
             FROM {identifier}.files
         """
     )
-    assert files_df.count() == expected_partitions
+    assert files_df.count() == 6

From 3711b1b5eb9e9a04b4d596d919920b729cdfbb9b Mon Sep 17 00:00:00 2001
From: Sung Yun <107272191+syun64@users.noreply.github.com>
Date: Tue, 7 May 2024 13:24:08 +0000
Subject: [PATCH 19/21] adopt review feedback

---
 Makefile                | 2 +-
 pyiceberg/transforms.py | 18 +++++++++++++++---
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/Makefile b/Makefile
index de50374cfb..35051be9c1 100644
--- a/Makefile
+++ b/Makefile
@@ -43,7 +43,7 @@ test-integration:
 	sleep 10
 	docker compose -f dev/docker-compose-integration.yml cp ./dev/provision.py spark-iceberg:/opt/spark/provision.py
 	docker compose -f dev/docker-compose-integration.yml exec -T spark-iceberg ipython ./provision.py
-	poetry run pytest tests/integration/test_writes/test_partitioned_writes.py -v -m integration ${PYTEST_ARGS}
+	poetry run pytest tests/ -v -m integration ${PYTEST_ARGS}
 
 test-integration-rebuild:
 	docker compose -f dev/docker-compose-integration.yml kill
diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py
index c8af97c301..f4d0640d43 100644
--- a/pyiceberg/transforms.py
+++ b/pyiceberg/transforms.py
@@ -182,6 +182,9 @@ def __eq__(self, other: Any) -> bool:
     def supports_pyarrow_transform(self) -> bool:
         return False
 
+    @abstractmethod
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": ...
+
 
 class BucketTransform(Transform[S, int]):
     """Base Transform class to transform a value into a bucket partition value.
@@ -297,6 +300,9 @@ def __repr__(self) -> str: """Return the string representation of the BucketTransform class.""" return f"BucketTransform(num_buckets={self._num_buckets})" + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": + raise NotImplementedError() + class TimeResolution(IntEnum): YEAR = 6 @@ -356,9 +362,6 @@ def dedup_name(self) -> str: def preserves_order(self) -> bool: return True - @abstractmethod - def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": ... - @property def supports_pyarrow_transform(self) -> bool: return True @@ -810,6 +813,9 @@ def __repr__(self) -> str: """Return the string representation of the TruncateTransform class.""" return f"TruncateTransform(width={self._width})" + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": + raise NotImplementedError() + @singledispatch def _human_string(value: Any, _type: IcebergType) -> str: @@ -892,6 +898,9 @@ def __repr__(self) -> str: """Return the string representation of the UnknownTransform class.""" return f"UnknownTransform(transform={repr(self._transform)})" + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": + raise NotImplementedError() + class VoidTransform(Transform[S, None], Singleton): """A transform that always returns None.""" @@ -920,6 +929,9 @@ def __repr__(self) -> str: """Return the string representation of the VoidTransform class.""" return "VoidTransform()" + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": + raise NotImplementedError() + def _truncate_number( name: str, pred: BoundLiteralPredicate[L], transform: Callable[[Optional[L]], Optional[L]] From f16d77880b8d835caed554bb06d3bf605190ba2b Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Wed, 8 May 2024 22:12:57 +0000 Subject: [PATCH 20/21] comment --- pyiceberg/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py index f4d0640d43..38cc6221a2 100644 --- a/pyiceberg/transforms.py +++ b/pyiceberg/transforms.py @@ -592,7 +592,7 @@ def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Arr elif isinstance(source, TimestamptzType): epoch = datetime.EPOCH_TIMESTAMPTZ else: - raise ValueError(f"Cannot apply month transform for type: {source}") + raise ValueError(f"Cannot apply hour transform for type: {source}") return lambda v: pc.hours_between(pa.scalar(epoch), v) if v is not None else None From 9f0a92bfb45b4d4c5af4400a1d485826dc4449c5 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Fri, 31 May 2024 18:52:23 +0000 Subject: [PATCH 21/21] rebase --- tests/integration/test_writes/test_partitioned_writes.py | 4 ++-- tests/test_transforms.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py index 2f2aabc1fc..76d559ca57 100644 --- a/tests/integration/test_writes/test_partitioned_writes.py +++ b/tests/integration/test_writes/test_partitioned_writes.py @@ -487,7 +487,7 @@ def test_append_transform_partition_verify_partitions_count( partitions_table = tbl.inspect.partitions() assert partitions_table.num_rows == len(expected_partitions) - assert {part[part_col] for part in partitions_table['partition'].to_pylist()} == expected_partitions + assert {part[part_col] for part in 
partitions_table["partition"].to_pylist()} == expected_partitions files_df = spark.sql( f""" SELECT * @@ -543,7 +543,7 @@ def test_append_multiple_partitions( partitions_table = tbl.inspect.partitions() assert partitions_table.num_rows == 6 - partitions = partitions_table['partition'].to_pylist() + partitions = partitions_table["partition"].to_pylist() assert {(part["date_year"], part["timestamptz_hour"]) for part in partitions} == { (53, 473328), (54, 473352), diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 15ef7d0ea2..3a9ffd6009 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1815,7 +1815,7 @@ def test_strict_binary(bound_reference_binary: BoundReference[str]) -> None: @pytest.mark.parametrize( - 'transform', + "transform", [ pytest.param(YearTransform(), id="year_transform"), pytest.param(MonthTransform(), id="month_transform"),