diff --git a/pyiceberg/expressions/__init__.py b/pyiceberg/expressions/__init__.py
index 5adf3a8a48..8829b4a3fb 100644
--- a/pyiceberg/expressions/__init__.py
+++ b/pyiceberg/expressions/__init__.py
@@ -383,6 +383,10 @@ def __repr__(self) -> str:
     @abstractmethod
     def as_bound(self) -> Type[BoundUnaryPredicate[Any]]: ...
 
+    def __hash__(self) -> int:
+        """Return hash value of the UnaryPredicate class."""
+        return hash(str(self))
+
 
 class BoundUnaryPredicate(BoundPredicate[L], ABC):
     def __repr__(self) -> str:
@@ -412,6 +416,10 @@ def __invert__(self) -> BoundNotNull[L]:
     def as_unbound(self) -> Type[IsNull]:
         return IsNull
 
+    def __hash__(self) -> int:
+        """Return hash value of the BoundIsNull class."""
+        return hash(str(self))
+
 
 class BoundNotNull(BoundUnaryPredicate[L]):
     def __new__(cls, term: BoundTerm[L]):  # type: ignore # pylint: disable=W0221
@@ -698,6 +706,10 @@ def __repr__(self) -> str:
     @abstractmethod
     def as_bound(self) -> Type[BoundLiteralPredicate[L]]: ...
 
+    def __hash__(self) -> int:
+        """Return hash value of the LiteralPredicate class."""
+        return hash(str(self))
+
 
 class BoundLiteralPredicate(BoundPredicate[L], ABC):
     literal: Literal[L]
@@ -731,6 +743,10 @@ def __invert__(self) -> BoundNotEqualTo[L]:
     def as_unbound(self) -> Type[EqualTo[L]]:
         return EqualTo
 
+    def __hash__(self) -> int:
+        """Return hash value of the BoundEqualTo class."""
+        return hash(str(self))
+
 
 class BoundNotEqualTo(BoundLiteralPredicate[L]):
     def __invert__(self) -> BoundEqualTo[L]:
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 5c26f1f96c..43a7132175 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -1725,6 +1725,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT
         file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'  # generate_data_file_filename
         schema = table_metadata.schema()
+        arrow_file_schema = schema_to_pyarrow(schema)
 
         fo = io.new_output(file_path)
 
@@ -1735,7 +1736,9 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT
         )
         with fo.create(overwrite=True) as fos:
             with pq.ParquetWriter(fos, schema=arrow_file_schema, **parquet_writer_kwargs) as writer:
-                writer.write_table(task.df, row_group_size=row_group_size)
+                # Align the columns in case the input arrow table has them in a different order than the iceberg table schema.
+                df_to_write = task.df.select(arrow_file_schema.names)
+                writer.write_table(df_to_write, row_group_size=row_group_size)
 
         data_file = DataFile(
             content=DataFileContent.DATA,
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 42ca85cccd..c60850c1af 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -51,10 +51,15 @@
 from pyiceberg.exceptions import CommitFailedException, ResolveError, ValidationError
 from pyiceberg.expressions import (
+    AlwaysFalse,
     AlwaysTrue,
     And,
     BooleanExpression,
+    BoundEqualTo,
+    BoundIsNull,
+    BoundPredicate,
     EqualTo,
+    IsNull,
     Reference,
     parser,
     visitors,
@@ -115,14 +120,7 @@
 )
 from pyiceberg.table.sorting import SortOrder
 from pyiceberg.transforms import TimeTransform, Transform, VoidTransform
-from pyiceberg.typedef import (
-    EMPTY_DICT,
-    IcebergBaseModel,
-    IcebergRootModel,
-    Identifier,
-    KeyDefaultDict,
-    Properties,
-)
+from pyiceberg.typedef import EMPTY_DICT, IcebergBaseModel, IcebergRootModel, Identifier, KeyDefaultDict, L, Properties
 from pyiceberg.types import (
     IcebergType,
     ListType,
@@ -150,7 +148,97 @@
 _JAVA_LONG_MAX = 9223372036854775807
 
 
-def _check_schema(table_schema: Schema, other_schema: "pa.Schema") -> None:
+def _bind_and_validate_static_overwrite_filter_predicate(
+    unbound_expr: Union[IsNull, EqualTo[L]], table_schema: Schema, spec: PartitionSpec
+) -> Union[BoundIsNull[L], BoundEqualTo[L]]:
+    # step 1: check that the unbound_expr is within the schema and that its value matches the field type
+    bound_expr: Union[BoundIsNull[L], BoundEqualTo[L], AlwaysFalse] = unbound_expr.bind(table_schema)  # type: ignore # The bind returns upcast types.
+
+    # step 2: check that a non-nullable partition column is not overwritten with IsNull.
+    # This must fail because we cannot fill null values into the input arrow table (and the parquet files to write) for a non-nullable iceberg field.
+    if isinstance(bound_expr, AlwaysFalse):
+        raise ValueError(
+            "Static overwriting with part of the explicit partition filter is not meaningful (specifying a non-nullable partition field to be null)."
+        )
+
+    # step 3: check that the unbound_expr is within the partition spec
+    if not isinstance(bound_expr, (BoundIsNull, BoundEqualTo)):
+        raise ValueError(
+            f"{unbound_expr=} binds to {bound_expr=}, which has an unexpected type. Expecting BoundIsNull or BoundEqualTo."
+        )
+    nested_field: NestedField = bound_expr.term.ref().field
+    part_fields: List[PartitionField] = spec.fields_by_source_id(nested_field.field_id)
+    if len(part_fields) != 1:
+        raise ValueError(f"Get {len(part_fields)} partition fields from filter predicate {str(unbound_expr)}, expecting 1.")
+    part_field = part_fields[0]
+
+    # step 4: check that the unbound_expr uses an identity transform
+    if not isinstance(part_field.transform, IdentityTransform):
+        raise ValueError(
+            f"static overwrite partition filter can only apply to partition fields without hidden transforms, but got {part_field.transform=} for {nested_field=}"
+        )
+
+    return bound_expr
+
+
+def _validate_static_overwrite_filter(
+    table_schema: Schema, overwrite_filter: BooleanExpression, spec: PartitionSpec
+) -> Tuple[Set[BoundIsNull[L]], Set[BoundEqualTo[L]]]:
+    is_null_predicates, eq_to_predicates = _validate_static_overwrite_filter_expr_type(expr=overwrite_filter)
+
+    bound_is_null_preds = set()
+    bound_eq_to_preds = set()
+    for unbound_is_null in is_null_predicates:
+        bound_pred = _bind_and_validate_static_overwrite_filter_predicate(
+            unbound_expr=unbound_is_null, table_schema=table_schema, spec=spec
+        )
+        if not isinstance(bound_pred, BoundIsNull):
+            raise ValueError(f"Expecting IsNull after binding {unbound_is_null} to schema but got {bound_pred}.")
+        bound_is_null_preds.add(bound_pred)
+
+    for unbound_eq_to in eq_to_predicates:
+        bound_pred = _bind_and_validate_static_overwrite_filter_predicate(
+            unbound_expr=unbound_eq_to, table_schema=table_schema, spec=spec
+        )
+        if not isinstance(bound_pred, BoundEqualTo):
+            raise ValueError(f"Expecting EqualTo after binding {unbound_eq_to} to schema but got {bound_pred}.")
+        bound_eq_to_preds.add(bound_pred)
+    return (bound_is_null_preds, bound_eq_to_preds)  # type: ignore
+
+
+def _fill_in_df(
+    df: pa.Table, bound_is_null_predicates: Set[BoundIsNull[L]], bound_eq_to_predicates: Set[BoundEqualTo[L]]
+) -> pa.Table:
+    """Use bound filter predicates to extend the pyarrow table with the filter columns of the iceberg schema and fill in their values."""
+    try:
+        import pyarrow as pa
+    except ModuleNotFoundError as e:
+        raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e
+
+    is_null_nested_fields = [bound_predicate.term.ref().field for bound_predicate in bound_is_null_predicates]
+    eq_to_nested_fields = [bound_predicate.term.ref().field for bound_predicate in bound_eq_to_predicates]
+
+    schema = Schema(*chain(is_null_nested_fields, eq_to_nested_fields))
+    from pyiceberg.io.pyarrow import schema_to_pyarrow
+
+    pa_schema = schema_to_pyarrow(schema)
+
+    is_null_nested_field_name_values = zip(
+        [nested_field.name for nested_field in is_null_nested_fields], [None] * len(bound_is_null_predicates)
+    )
+    eq_to_nested_field_name_values = zip(
+        [nested_field.name for nested_field in eq_to_nested_fields],
+        [predicate.literal.value for predicate in bound_eq_to_predicates],
+    )
+
+    for field_name, value in chain(is_null_nested_field_name_values, eq_to_nested_field_name_values):
+        pa_field = pa_schema.field(field_name)
+        literal_array = pa.array([value] * df.num_rows, type=pa_field.type)
+        df = df.add_column(df.num_columns, field_name, literal_array)
+    return df
+
+
+def _arrow_schema_to_iceberg_schema_with_field_ids(table_schema: Schema, other_schema: "pa.Schema") -> Schema:
     from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema
 
     name_mapping = table_schema.name_mapping
@@ -162,8 +250,56 @@ def _check_schema(table_schema: Schema, other_schema: "pa.Schema") -> None:
         raise ValueError(
             f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)."
         ) from e
+    return task_schema
+
+
+def _check_schema_with_filter_predicates(
+    table_schema: Schema, other_schema: "pa.Schema", filter_predicates: Set[BoundPredicate[L]]
+) -> None:
+    task_schema = _arrow_schema_to_iceberg_schema_with_field_ids(table_schema, other_schema)
+
+    filter_fields = [bound_predicate.term.ref().field for bound_predicate in filter_predicates]
+    remaining_schema = _truncate_fields(table_schema, to_truncate=filter_fields)
+    sorted_remaining_schema = Schema(*sorted(remaining_schema.fields, key=lambda field: field.field_id))
+    sorted_task_schema = Schema(*sorted(task_schema.fields, key=lambda field: field.field_id))
+    if sorted_remaining_schema.as_struct() != sorted_task_schema.as_struct():
+        from rich.console import Console
+        from rich.table import Table as RichTable
+
+        console = Console(record=True)
+
+        rich_table = RichTable(show_header=True, header_style="bold")
+        rich_table.add_column("")
+        rich_table.add_column("Table field")
+        rich_table.add_column("Dataframe field")
+        rich_table.add_column("Overwrite filter field")
+
+        filter_field_names = [field.name for field in filter_fields]
+        for lhs in table_schema.fields:
+            if lhs.name in filter_field_names:
+                try:
+                    rhs = task_schema.find_field(lhs.field_id)
+                    rich_table.add_row("❌", str(lhs), str(rhs), lhs.name)
+                except ValueError:
+                    rich_table.add_row("✅", str(lhs), "N/A", lhs.name)
+            else:
+                try:
+                    rhs = task_schema.find_field(lhs.field_id)
+                    rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs), "N/A")
+                except ValueError:
+                    rich_table.add_row("❌", str(lhs), "Missing", "N/A")
+
+        console.print(rich_table)
+        raise ValueError(f"Mismatch in fields:\n{console.export_text()}")
+
+
+def _check_schema(table_schema: Schema, other_schema: "pa.Schema") -> None:
+    task_schema = _arrow_schema_to_iceberg_schema_with_field_ids(table_schema, other_schema)
+
+    sorted_table_schema = Schema(*sorted(table_schema.fields, key=lambda field: field.field_id))
+    sorted_task_schema = Schema(*sorted(task_schema.fields, key=lambda field: field.field_id))
 
-    if table_schema.as_struct() != task_schema.as_struct():
+    if sorted_table_schema.as_struct() != sorted_task_schema.as_struct():
from rich.console import Console from rich.table import Table as RichTable @@ -185,6 +321,136 @@ def _check_schema(table_schema: Schema, other_schema: "pa.Schema") -> None: raise ValueError(f"Mismatch in fields:\n{console.export_text()}") +# def _check_schema( +# table_schema: Schema, other_schema: "pa.Schema", filter_predicates: Set[BoundPredicate[L]] | None = None +# ) -> None: +# if filter_predicates is None: +# filter_predicates = set() + +# from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema + +# name_mapping = table_schema.name_mapping +# try: +# task_schema = pyarrow_to_schema(other_schema, name_mapping=name_mapping) +# except ValueError as e: +# other_schema = _pyarrow_to_schema_without_ids(other_schema) +# additional_names = set(other_schema.column_names) - set(table_schema.column_names) +# raise ValueError( +# f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)." +# ) from e + +# def compare_and_rich_print(table_schema: Schema, task_schema: Schema) -> None: +# sorted_table_schema = Schema(*sorted(table_schema.fields, key=lambda field: field.field_id)) +# sorted_task_schema = Schema(*sorted(task_schema.fields, key=lambda field: field.field_id)) +# if sorted_table_schema.as_struct() != sorted_task_schema.as_struct(): +# from rich.console import Console +# from rich.table import Table as RichTable + +# console = Console(record=True) + +# rich_table = RichTable(show_header=True, header_style="bold") +# rich_table.add_column("") +# rich_table.add_column("Table field") +# rich_table.add_column("Dataframe field") + +# for lhs in table_schema.fields: +# try: +# rhs = task_schema.find_field(lhs.field_id) +# rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs)) +# except ValueError: +# rich_table.add_row("❌", str(lhs), "Missing") + +# console.print(rich_table) +# raise ValueError(f"Mismatch in fields:\n{console.export_text()}") + +# def compare_and_rich_print_with_filter( +# table_schema: Schema, task_schema: Schema, filter_predicates: Set[BoundPredicate[L]] +# ) -> None: +# filter_fields = [bound_predicate.term.ref().field for bound_predicate in filter_predicates] +# remaining_schema = _truncate_fields(table_schema, to_truncate=filter_fields) +# sorted_remaining_schema = Schema(*sorted(remaining_schema.fields, key=lambda field: field.field_id)) +# sorted_task_schema = Schema(*sorted(task_schema.fields, key=lambda field: field.field_id)) +# if sorted_remaining_schema.as_struct() != sorted_task_schema.as_struct(): +# from rich.console import Console +# from rich.table import Table as RichTable + +# console = Console(record=True) + +# rich_table = RichTable(show_header=True, header_style="bold") +# rich_table.add_column("") +# rich_table.add_column("Table field") +# rich_table.add_column("Dataframe field") +# rich_table.add_column("Overwrite filter field") + +# filter_field_names = [field.name for field in filter_fields] +# for lhs in table_schema.fields: +# if lhs.name in filter_field_names: +# try: +# rhs = task_schema.find_field(lhs.field_id) +# rich_table.add_row("❌", str(lhs), str(rhs), lhs.name) +# except ValueError: +# rich_table.add_row("✅", str(lhs), "N/A", lhs.name) +# else: +# try: +# rhs = task_schema.find_field(lhs.field_id) +# rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs), "N/A") +# except ValueError: +# rich_table.add_row("❌", str(lhs), "Missing", "N/A") + +# console.print(rich_table) +# raise ValueError(f"Mismatch in 
fields:\n{console.export_text()}")
+
+#     if len(filter_predicates) != 0:
+#         compare_and_rich_print_with_filter(table_schema, task_schema, filter_predicates)
+#     else:
+#         compare_and_rich_print(table_schema, task_schema)
+
+
+def _truncate_fields(table_schema: Schema, to_truncate: List[NestedField]) -> Schema:
+    to_truncate_fields_source_ids = {nested_field.field_id for nested_field in to_truncate}
+    truncated = [field for field in table_schema.fields if field.field_id not in to_truncate_fields_source_ids]
+    return Schema(*truncated)
+
+
+def _validate_static_overwrite_filter_expr_type(expr: BooleanExpression) -> Tuple[Set[IsNull], Set[EqualTo[L]]]:
+    """Validate that the expression contains only And, IsNull, and EqualTo, and break it down into IsNull and EqualTo predicates."""
+    from collections import defaultdict
+
+    def _recursively_fetch_fields(
+        expr: BooleanExpression, is_null_predicates: Set[IsNull], eq_to_predicates: Set[EqualTo[L]]
+    ) -> None:
+        if isinstance(expr, EqualTo):
+            if not isinstance(expr.term, Reference):
+                raise ValueError(f"Unsupported unbound term {expr.term} in {expr}, expecting a reference.")
+            duplication_check[expr.term.name].add(expr)
+            eq_to_predicates.add(expr)
+        elif isinstance(expr, IsNull):
+            if not isinstance(expr.term, Reference):
+                raise ValueError(f"Unsupported unbound term {expr.term} in {expr}, expecting a reference.")
+            duplication_check[expr.term.name].add(expr)
+            is_null_predicates.add(expr)
+        elif isinstance(expr, And):
+            _recursively_fetch_fields(expr.left, is_null_predicates, eq_to_predicates)
+            _recursively_fetch_fields(expr.right, is_null_predicates, eq_to_predicates)
+        else:
+            raise ValueError(
+                f"static overwrite partitioning filter can only be composed of EqualTo, IsNull, and And (or be AlwaysTrue), but got {expr=}"
+            )
+
+    duplication_check: Dict[str, Set[Union[IsNull, EqualTo[L]]]] = defaultdict(set)
+    is_null_predicates: Set[IsNull] = set()
+    eq_to_predicates: Set[EqualTo[L]] = set()
+    _recursively_fetch_fields(expr, is_null_predicates, eq_to_predicates)
+    for _, expr_set in duplication_check.items():
+        if len(expr_set) != 1:
+            raise ValueError(
+                f"static overwrite partitioning filter has more than one distinct predicate on the same field: {expr_set}"
+            )
+
+    # Each field may appear in only one distinct predicate; consider consolidating this check with the other filter field checks.
+    return is_null_predicates, eq_to_predicates
+
+
 class TableProperties:
     PARQUET_ROW_GROUP_SIZE_BYTES = "write.parquet.row-group-size-bytes"
     PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT = 128 * 1024 * 1024  # 128 MB
@@ -238,11 +504,11 @@ class PartitionProjector:
     def __init__(
         self,
         table_metadata: TableMetadata,
-        row_filter: Union[str, BooleanExpression] = ALWAYS_TRUE,
+        row_filter: BooleanExpression = ALWAYS_TRUE,
         case_sensitive: bool = True,
     ):
         self.table_metadata = table_metadata
-        self.row_filter = _parse_row_filter(row_filter)
+        self.row_filter = _parse_row_filter(row_filter)  # TODO: make this a BooleanExpression
         self.case_sensitive = case_sensitive
 
     def _build_partition_projection(self, spec_id: int) -> BooleanExpression:
@@ -1159,7 +1425,21 @@ def overwrite(self, df: pa.Table, overwrite_filter: Union[str, BooleanExpression
         if not isinstance(df, pa.Table):
             raise ValueError(f"Expected PyArrow table, got: {df}")
 
-        _check_schema(self.schema(), other_schema=df.schema)
+        overwrite_filter = _parse_row_filter(overwrite_filter)
+
+        if overwrite_filter != ALWAYS_TRUE:
+            bound_is_null_predicates, bound_eq_to_predicates = _validate_static_overwrite_filter(
+                table_schema=self.schema(),
overwrite_filter=overwrite_filter, spec=self.metadata.spec() + ) + + _check_schema_with_filter_predicates( + table_schema=self.schema(), + other_schema=df.schema, + filter_predicates=bound_is_null_predicates.union(bound_eq_to_predicates), + ) + df = _fill_in_df(df, bound_is_null_predicates, bound_eq_to_predicates) + else: + _check_schema(table_schema=self.schema(), other_schema=df.schema) with self.transaction() as txn: with txn.update_snapshot().overwrite(overwrite_filter) as update_snapshot: @@ -2480,17 +2760,16 @@ class _MergingSnapshotProducer(UpdateTableMetadata["_MergingSnapshotProducer"]): _io: FileIO _deleted_data_files: Optional[DeletedDataFiles] - # _manifests_compositions: Any #list[Callable[[_MergingSnapshotProducer], List[ManifestFile]]] def __init__( self, - operation: Operation, # done, inited + operation: Operation, transaction: Transaction, - io: FileIO, # done, inited - overwrite_filter: Union[str, BooleanExpression] = ALWAYS_TRUE, + io: FileIO, + overwrite_filter: BooleanExpression = ALWAYS_TRUE, commit_uuid: Optional[uuid.UUID] = None, ) -> None: super().__init__(transaction) - self.commit_uuid = commit_uuid or uuid.uuid4() # done + self.commit_uuid = commit_uuid or uuid.uuid4() self._io = io self._operation = operation self._snapshot_id = self._transaction.table_metadata.new_snapshot_id() @@ -2891,7 +3170,7 @@ def _get_deleted_entries(manifest: ManifestFile) -> List[ManifestEntry]: class PartialOverwriteFiles(_MergingSnapshotProducer): - def __init__(self, overwrite_filter: Union[str, BooleanExpression], **kwargs: Any) -> None: + def __init__(self, overwrite_filter: BooleanExpression, **kwargs: Any) -> None: super().__init__(**kwargs) self._deleted_data_files = ExplicitlyDeletedDataFiles() self.overwrite_filter = overwrite_filter @@ -2975,9 +3254,7 @@ def __init__(self, transaction: Transaction, io: FileIO) -> None: def fast_append(self) -> FastAppendFiles: return FastAppendFiles(operation=Operation.APPEND, transaction=self._transaction, io=self._io) - def overwrite( - self, overwrite_filter: Union[str, BooleanExpression] = ALWAYS_TRUE - ) -> Union[OverwriteFiles, PartialOverwriteFiles]: + def overwrite(self, overwrite_filter: BooleanExpression = ALWAYS_TRUE) -> Union[OverwriteFiles, PartialOverwriteFiles]: if overwrite_filter == ALWAYS_TRUE: return OverwriteFiles( operation=Operation.OVERWRITE diff --git a/tests/integration/test_partitioned_writes.py b/tests/integration/test_partitioned_writes.py index 7b50ab91d9..ea4c03e8d6 100644 --- a/tests/integration/test_partitioned_writes.py +++ b/tests/integration/test_partitioned_writes.py @@ -444,7 +444,8 @@ def test_summaries_with_null(spark: SparkSession, session_catalog: Catalog, arro tbl.append(arrow_table_with_null) tbl.overwrite(arrow_table_with_null) tbl.append(arrow_table_with_null) - tbl.overwrite(arrow_table_with_null, overwrite_filter="int=1") + valid_arrow_table_with_null_to_overwrite = arrow_table_with_null.drop(['int']) + tbl.overwrite(valid_arrow_table_with_null_to_overwrite, overwrite_filter="int=1") rows = spark.sql( f""" @@ -512,16 +513,16 @@ def test_summaries_with_null(spark: SparkSession, session_catalog: Catalog, arro # static overwrite which deletes 2 record (one from step3, one from step4) and 2 datafile, adding 3 new data files and 3 records, so total data files and records are 6 - 2 + 3 = 7 assert summaries[4] == { 'removed-files-size': '10790', - 'added-data-files': '3', + 'added-data-files': '1', 'total-equality-deletes': '0', 'added-records': '3', 'deleted-data-files': '2', 
'total-position-deletes': '0', - 'added-files-size': '15029', + 'added-files-size': '5455', 'total-delete-files': '0', 'deleted-records': '2', - 'total-files-size': '34297', - 'total-data-files': '7', + 'total-files-size': '24723', + 'total-data-files': '5', 'total-records': '7', } @@ -546,7 +547,9 @@ def test_data_files_with_table_partitioned_with_null( tbl.append(arrow_table_with_null) tbl.overwrite(arrow_table_with_null) tbl.append(arrow_table_with_null) - tbl.overwrite(arrow_table_with_null, overwrite_filter="int=1") + + valid_arrow_table_with_null_to_overwrite = arrow_table_with_null.drop(['int']) + tbl.overwrite(valid_arrow_table_with_null_to_overwrite, overwrite_filter="int=1") # first append links to 1 manifest file (M1) # second append's manifest list links to 2 manifest files (M1, M2) @@ -561,7 +564,7 @@ def test_data_files_with_table_partitioned_with_null( # M3 0 0 6 S3 # M4 3 0 0 S3 # M5 3 0 0 S4 - # M6 3 0 0 S5 + # M6 1 0 0 S5 # M7 0 4 2 S5 spark.sql( @@ -575,7 +578,7 @@ def test_data_files_with_table_partitioned_with_null( FROM {identifier}.all_manifests """ ).collect() - assert [row.added_data_files_count for row in rows] == [3, 3, 3, 3, 0, 3, 3, 3, 0] + assert [row.added_data_files_count for row in rows] == [3, 3, 3, 3, 0, 3, 3, 1, 0] assert [row.existing_data_files_count for row in rows] == [0, 0, 0, 0, 0, 0, 0, 0, 4] assert [row.deleted_data_files_count for row in rows] == [0, 0, 0, 0, 6, 0, 0, 0, 2] @@ -668,11 +671,19 @@ def test_query_filter_after_append_overwrite_table_with_expr( properties={'format-version': '1'}, ) - for _ in range(2): + for _ in range(3): tbl.append(arrow_table_with_null) - tbl.overwrite(arrow_table_with_null, expr) + spark.sql(f"refresh table {identifier}") + spark.sql(f"select file_path from {identifier}.files").show(20, False) + spark.sql(f"select * from {identifier}").show(20, False) + + valid_arrow_table_with_null_to_overwrite = arrow_table_with_null.drop([part_col]) + tbl.overwrite(valid_arrow_table_with_null_to_overwrite, expr) iceberg_table = session_catalog.load_table(identifier=identifier) + spark.sql(f"refresh table {identifier}") + spark.sql(f"select file_path from {identifier}.files").show(20, False) spark.sql(f"select * from {identifier}").show(20, False) - assert iceberg_table.scan(row_filter=expr).to_arrow().num_rows == 1 - assert iceberg_table.scan().to_arrow().num_rows == 7 + + assert iceberg_table.scan().to_arrow().num_rows == 9 + assert iceberg_table.scan(row_filter=expr).to_arrow().num_rows == 3 diff --git a/tests/table/test_init.py b/tests/table/test_init.py index be3a28199a..f5519a0f7e 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -15,9 +15,10 @@ # specific language governing permissions and limitations # under the License. 
# pylint:disable=redefined-outer-name +import re import uuid from copy import copy -from typing import Any, Dict +from typing import Any, Dict, Set, Union import pyarrow as pa import pytest @@ -29,8 +30,17 @@ from pyiceberg.expressions import ( AlwaysTrue, And, + BoundEqualTo, + BoundIsNull, + BoundPredicate, EqualTo, In, + IsNull, + LessThan, + NotEqualTo, + NotNull, + Or, + Reference, ) from pyiceberg.io import PY_IO_IMPL, load_file_io from pyiceberg.manifest import ( @@ -62,9 +72,13 @@ Table, UpdateSchema, _apply_table_update, + _bind_and_validate_static_overwrite_filter_predicate, _check_schema, + _check_schema_with_filter_predicates, + _fill_in_df, _match_deletes_to_data_file, _TableMetadataUpdateContext, + _validate_static_overwrite_filter_expr_type, update_table_metadata, verify_table_already_sorted, ) @@ -81,7 +95,8 @@ SortField, SortOrder, ) -from pyiceberg.transforms import BucketTransform, IdentityTransform +from pyiceberg.transforms import BucketTransform, IdentityTransform, TruncateTransform +from pyiceberg.typedef import L from pyiceberg.types import ( BinaryType, BooleanType, @@ -1014,7 +1029,307 @@ def test_correct_schema() -> None: assert "Snapshot not found: -1" in str(exc_info.value) -def test_schema_mismatch_type(table_schema_simple: Schema) -> None: +def test__bind_and_validate_static_overwrite_filter_predicate_fails_on_non_schema_fields_in_filter( + iceberg_schema_simple: Schema, +) -> None: + pred = EqualTo(Reference("not a field"), "hello") + partition_spec = PartitionSpec( + PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="test_part_col") + ) + with pytest.raises(ValueError, match="Could not find field with name not a field, case_sensitive=True"): + _bind_and_validate_static_overwrite_filter_predicate( + unbound_expr=pred, table_schema=iceberg_schema_simple, spec=partition_spec + ) + + +def test__fill_in_df(table_schema_simple: Schema) -> None: + df = pa.table({"baz": [True, False, None]}) + unbound_is_null_predicates = [IsNull(Reference("foo"))] + unbound_eq_to_predicates = [EqualTo(Reference("bar"), 3)] + bound_is_null_predicates: Set[BoundIsNull[Any]] = { + unbound_predicate.bind(table_schema_simple) # type: ignore # because bind returns super type and python could not downcast implicitly using type annotation + for unbound_predicate in unbound_is_null_predicates + } + bound_eq_to_predicates: Set[BoundEqualTo[Any]] = { + unbound_predicate.bind(table_schema_simple) # type: ignore # because bind returns super type and python could not downcast implicitly using type annotation + for unbound_predicate in unbound_eq_to_predicates + } + filled_df = _fill_in_df( + df=df, + bound_is_null_predicates=bound_is_null_predicates, + bound_eq_to_predicates=bound_eq_to_predicates, + ) + expected = pa.table( + {"baz": [True, False, None], "foo": [None, None, None], "bar": [3, 3, 3]}, + schema=pa.schema([pa.field('baz', pa.bool_()), pa.field('foo', pa.string()), pa.field('bar', pa.int32())]), + ) + assert filled_df == expected + + +def test__bind_and_validate_static_overwrite_filter_predicate_fails_on_non_part_fields_in_filter( + iceberg_schema_simple: Schema, +) -> None: + pred = EqualTo(Reference("foo"), "hello") + partition_spec = PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="bar")) + import re + + with pytest.raises( + ValueError, + match=re.escape( + "Get 0 partition fields from filter predicate EqualTo(term=Reference(name='foo'), literal=literal('hello')), expecting 1." 
+        ),
+    ):
+        _bind_and_validate_static_overwrite_filter_predicate(
+            unbound_expr=pred, table_schema=iceberg_schema_simple, spec=partition_spec
+        )
+
+
+def test__bind_and_validate_static_overwrite_filter_predicate_fails_on_non_identity_transform_filter(
+    iceberg_schema_simple: Schema,
+) -> None:
+    pred = EqualTo(Reference("foo"), "hello")
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="bar"),
+        PartitionField(source_id=1, field_id=1002, transform=TruncateTransform(2), name="foo_trunc"),
+    )
+    # import re
+    with pytest.raises(
+        ValueError,
+        match="static overwrite partition filter can only apply to partition fields without hidden transforms, but got.*",
+    ):
+        _bind_and_validate_static_overwrite_filter_predicate(
+            unbound_expr=pred, table_schema=iceberg_schema_simple, spec=partition_spec
+        )
+
+
+def test__bind_and_validate_static_overwrite_filter_predicate_succeeds_on_an_identity_transform_field_although_table_has_other_hidden_partition_fields(
+    iceberg_schema_simple: Schema,
+) -> None:
+    pred = EqualTo(Reference("bar"), 3)
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="bar"),
+        PartitionField(source_id=1, field_id=1002, transform=TruncateTransform(2), name="foo_trunc"),
+    )
+
+    _bind_and_validate_static_overwrite_filter_predicate(
+        unbound_expr=pred, table_schema=iceberg_schema_simple, spec=partition_spec
+    )
+
+
+def test__bind_and_validate_static_overwrite_filter_predicate_fails_to_bind_due_to_incompatible_predicate_value(
+    iceberg_schema_simple: Schema,
+) -> None:
+    pred = EqualTo(Reference("bar"), "an incompatible type")
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="bar"),
+        PartitionField(source_id=1, field_id=1002, transform=TruncateTransform(2), name="foo_trunc"),
+    )
+    with pytest.raises(ValueError, match="Could not convert an incompatible type into a int"):
+        _bind_and_validate_static_overwrite_filter_predicate(
+            unbound_expr=pred, table_schema=iceberg_schema_simple, spec=partition_spec
+        )
+
+
+def test__bind_and_validate_static_overwrite_filter_predicate_fails_to_bind_due_to_non_nullable(
+    iceberg_schema_simple: Schema,
+) -> None:
+    pred = IsNull(Reference("bar"))
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=3, field_id=1001, transform=IdentityTransform(), name="baz"),
+        PartitionField(source_id=1, field_id=1002, transform=TruncateTransform(2), name="foo_trunc"),
+    )
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "Static overwriting with part of the explicit partition filter is not meaningful (specifying a non-nullable partition field to be null)"
+        ),
+    ):
+        _bind_and_validate_static_overwrite_filter_predicate(
+            unbound_expr=pred, table_schema=iceberg_schema_simple, spec=partition_spec
+        )
+
+
+def test__check_schema_with_filter_succeed(iceberg_schema_simple: Schema) -> None:
+    other_schema: pa.Schema = pa.schema([
+        pa.field('foo', pa.string()),
+        pa.field('baz', pa.bool_()),
+    ])
+
+    unbound_preds = [EqualTo(Reference("bar"), 15)]
+    filter_predicates: Set[BoundPredicate[int]] = {pred.bind(iceberg_schema_simple) for pred in unbound_preds}
+    # upcast set[BoundLiteralPredicate[int]] to set[BoundPredicate[int]]
+    # because _check_schema expects set[BoundPredicate[int]] and set is not covariant
+    # (meaning although BoundLiteralPredicate is subclass of BoundPredicate, list[BoundLiteralPredicate] is not that of list[BoundPredicate])
+    
_check_schema_with_filter_predicates(iceberg_schema_simple, other_schema, filter_predicates=filter_predicates) + + +def test__check_schema_with_filter_succeed_on_pyarrow_table_with_random_column_order(iceberg_schema_simple: Schema) -> None: + other_schema: pa.Schema = pa.schema([ + pa.field('baz', pa.bool_()), + pa.field('foo', pa.string()), + ]) + + unbound_preds = [EqualTo(Reference("bar"), 15)] + filter_predicates: Set[BoundPredicate[int]] = {pred.bind(iceberg_schema_simple) for pred in unbound_preds} + # upcast set[BoundLiteralPredicate[int]] to set[BoundPredicate[int]] + # because _check_schema expects set[BoundPredicate[int]] and set is not covariant + # (meaning although BoundLiteralPredicate is subclass of BoundPredicate, list[BoundLiteralPredicate] is not that of list[BoundPredicate]) + _check_schema_with_filter_predicates(iceberg_schema_simple, other_schema, filter_predicates=filter_predicates) + + +def test__check_schema_with_filter_fails_on_missing_field(iceberg_schema_simple: Schema) -> None: + other_schema: pa.Schema = pa.schema([ + pa.field('baz', pa.bool_()), + ]) + + unbound_preds = [EqualTo(Reference("bar"), 15)] + filter_predicates: Set[BoundPredicate[int]] = {pred.bind(iceberg_schema_simple) for pred in unbound_preds} + # upcast set[BoundLiteralPredicate[int]] to set[BoundPredicate[int]] + # because _check_schema expects set[BoundPredicate[int]] and set is not covariant + # (meaning although BoundLiteralPredicate is subclass of BoundPredicate, list[BoundLiteralPredicate] is not that of list[BoundPredicate]) + + expected = re.escape( + """Mismatch in fields: +┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ ┃ Table field ┃ Dataframe field ┃ Overwrite filter field ┃ +┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ ❌ │ 1: foo: optional string │ Missing │ N/A │ +│ ✅ │ 2: bar: required int │ N/A │ bar │ +│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │ N/A │ +└────┴──────────────────────────┴──────────────────────────┴────────────────────────┘ +""" + ) + with pytest.raises(ValueError, match=expected): + _check_schema_with_filter_predicates(iceberg_schema_simple, other_schema, filter_predicates=filter_predicates) + + +def test__check_schema_with_filter_fails_on_nullability_mismatch(iceberg_schema_simple: Schema) -> None: + other_schema: pa.Schema = pa.schema([ + pa.field('foo', pa.string()), + pa.field('bar', pa.int32()), + ]) + + unbound_preds = [EqualTo(Reference("baz"), True)] + filter_predicates: Set[BoundPredicate[bool]] = {pred.bind(iceberg_schema_simple) for pred in unbound_preds} + # upcast set[BoundLiteralPredicate[bool]] to set[BoundPredicate[bool]] + # because _check_schema expects set[BoundPredicate[bool]] and set is not covariant + # (meaning although BoundLiteralPredicate is subclass of BoundPredicate, list[BoundLiteralPredicate] is not that of list[BoundPredicate]) + expected = re.escape( + """Mismatch in fields: +┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ ┃ Table field ┃ Dataframe field ┃ Overwrite filter field ┃ +┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ ✅ │ 1: foo: optional string │ 1: foo: optional string │ N/A │ +│ ❌ │ 2: bar: required int │ 2: bar: optional int │ N/A │ +│ ✅ │ 3: baz: optional boolean │ N/A │ baz │ +└────┴──────────────────────────┴─────────────────────────┴────────────────────────┘ +""" + ) + with pytest.raises(ValueError, match=expected): + 
_check_schema_with_filter_predicates(iceberg_schema_simple, other_schema, filter_predicates=filter_predicates) + + +def test__check_schema_with_filter_fails_on_type_mismatch(iceberg_schema_simple: Schema) -> None: + other_schema: pa.Schema = pa.schema([ + pa.field('foo', pa.string()), + pa.field('bar', pa.string(), nullable=False), + ]) + + unbound_preds = [EqualTo(Reference("baz"), True)] + filter_predicates: Set[BoundPredicate[bool]] = {pred.bind(iceberg_schema_simple) for pred in unbound_preds} + # upcast set[BoundLiteralPredicate[bool]] to set[BoundPredicate[bool]] + # because _check_schema expects set[BoundPredicate[bool]] and set is not covariant + # (meaning although BoundLiteralPredicate is subclass of BoundPredicate, list[BoundLiteralPredicate] is not that of list[BoundPredicate]) + expected = re.escape( + """Mismatch in fields: +┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ ┃ Table field ┃ Dataframe field ┃ Overwrite filter field ┃ +┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ ✅ │ 1: foo: optional string │ 1: foo: optional string │ N/A │ +│ ❌ │ 2: bar: required int │ 2: bar: required string │ N/A │ +│ ✅ │ 3: baz: optional boolean │ N/A │ baz │ +└────┴──────────────────────────┴─────────────────────────┴────────────────────────┘ +""" + ) + with pytest.raises(ValueError, match=expected): + _check_schema_with_filter_predicates(iceberg_schema_simple, other_schema, filter_predicates=filter_predicates) + + +def test__check_schema_with_filter_fails_due_to_filter_and_dataframe_holding_same_field(iceberg_schema_simple: Schema) -> None: + other_schema: pa.Schema = pa.schema([ + pa.field('foo', pa.string()), + pa.field('bar', pa.int32(), nullable=False), + ]) + + unbound_preds = [IsNull(Reference("foo")), EqualTo(Reference("baz"), True)] + filter_predicates: Set[BoundPredicate[Any]] = {pred.bind(iceberg_schema_simple) for pred in unbound_preds} # type: ignore # bind returns BoundLiteralPredicate and BoundUnaryPredicate and thus set has type inferred as Set[BooleanExpression] which could not be downcast to Set[BoundPredicate[Any]] implicitly using :. 
+ expected = re.escape( + """Mismatch in fields: +┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ ┃ Table field ┃ Dataframe field ┃ Overwrite filter field ┃ +┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ ❌ │ 1: foo: optional string │ 1: foo: optional string │ foo │ +│ ✅ │ 2: bar: required int │ 2: bar: required int │ N/A │ +│ ✅ │ 3: baz: optional boolean │ N/A │ baz │ +└────┴──────────────────────────┴─────────────────────────┴────────────────────────┘ +""" + ) + with pytest.raises(ValueError, match=expected): + _check_schema_with_filter_predicates(iceberg_schema_simple, other_schema, filter_predicates=filter_predicates) + + +@pytest.mark.parametrize( + "pred, raises, is_null_preds, eq_to_preds", + [ + (EqualTo(Reference("foo"), "hello"), False, {}, {EqualTo(Reference("foo"), "hello")}), + (IsNull(Reference("foo")), False, {IsNull(Reference("foo"))}, {}), + ( + And(IsNull(Reference("foo")), EqualTo(Reference("boo"), "hello")), + False, + {IsNull(Reference("foo"))}, + {EqualTo(Reference("boo"), "hello")}, + ), + (NotNull, True, {}, {}), + (NotEqualTo, True, {}, {}), + (LessThan(Reference("foo"), 5), True, {}, {}), + (Or(IsNull(Reference("foo")), EqualTo(Reference("foo"), "hello")), True, {}, {}), + ( + And(EqualTo(Reference("foo"), "hello"), And(IsNull(Reference("baz")), EqualTo(Reference("boo"), "hello"))), + False, + {IsNull(Reference("baz"))}, + {EqualTo(Reference("foo"), "hello"), EqualTo(Reference("boo"), "hello")}, + ), + # Below are crowd-crush tests: a same field can only be with same literal/null, not different literals or both literal and null + # A false crush: when there are duplicated isnull/equalto, the collector should deduplicate them. + ( + And(EqualTo(Reference("foo"), "hello"), EqualTo(Reference("foo"), "hello")), + False, + {}, + {EqualTo(Reference("foo"), "hello")}, + ), + # When crush happens + ( + And(EqualTo(Reference("foo"), "hello"), EqualTo(Reference("foo"), "bye")), + True, + {}, + {EqualTo(Reference("foo"), "hello"), EqualTo(Reference("foo"), "bye")}, + ), + (And(EqualTo(Reference("foo"), "hello"), IsNull(Reference("foo"))), True, {IsNull(Reference("foo"))}, {}), + ], +) +def test__validate_static_overwrite_filter_expr_type( + pred: Union[IsNull, EqualTo[Any]], raises: bool, is_null_preds: Set[IsNull], eq_to_preds: Set[EqualTo[L]] +) -> None: + if raises: + with pytest.raises(ValueError): + res = _validate_static_overwrite_filter_expr_type(pred) + else: + res = _validate_static_overwrite_filter_expr_type(pred) + assert {str(e) for e in res[0]} == {str(e) for e in is_null_preds} + assert {str(e) for e in res[1]} == {str(e) for e in eq_to_preds} + + +def test_check_schema_mismatch_type(table_schema_simple: Schema) -> None: other_schema = pa.schema(( pa.field("foo", pa.string(), nullable=True), pa.field("bar", pa.decimal128(18, 6), nullable=False), @@ -1035,7 +1350,7 @@ def test_schema_mismatch_type(table_schema_simple: Schema) -> None: _check_schema(table_schema_simple, other_schema) -def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None: +def test_check_schema_mismatch_nullability(table_schema_simple: Schema) -> None: other_schema = pa.schema(( pa.field("foo", pa.string(), nullable=True), pa.field("bar", pa.int32(), nullable=True), @@ -1056,7 +1371,7 @@ def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None: _check_schema(table_schema_simple, other_schema) -def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None: +def 
test_check_schema_mismatch_missing_field(table_schema_simple: Schema) -> None: other_schema = pa.schema(( pa.field("foo", pa.string(), nullable=True), pa.field("baz", pa.bool_(), nullable=True), @@ -1076,6 +1391,26 @@ def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None: _check_schema(table_schema_simple, other_schema) +def test_check_schema_succeed(table_schema_simple: Schema) -> None: + other_schema = pa.schema(( + pa.field("foo", pa.string(), nullable=True), + pa.field("bar", pa.int32(), nullable=False), + pa.field("baz", pa.bool_(), nullable=True), + )) + + _check_schema(table_schema_simple, other_schema) + + +def test_schema_succeed_on_pyarrow_table_reversed_column_order(iceberg_schema_simple: Schema) -> None: + other_schema = pa.schema(( + pa.field("baz", pa.bool_(), nullable=True), + pa.field("bar", pa.int32(), nullable=False), + pa.field("foo", pa.string(), nullable=True), + )) + + _check_schema(iceberg_schema_simple, other_schema) + + def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None: other_schema = pa.schema(( pa.field("foo", pa.string(), nullable=True), @@ -1116,7 +1451,6 @@ def test_table_properties_raise_for_none_value(example_table_metadata_v2: Dict[s assert "None type is not a supported value in properties: property_name" in str(exc_info.value) -@pytest.mark.integration @pytest.mark.parametrize( "input_sorted_indices, expected_sorted_or_not", [