From 10a7c41e8c98ec8f4546b866990ba713d1d34635 Mon Sep 17 00:00:00 2001
From: Fokko Driesprong
Date: Wed, 6 Aug 2025 20:22:37 +0200
Subject: [PATCH 1/3] Convert `_get_column_projection_values` to use Field-IDs

---
 pyiceberg/expressions/visitors.py  | 19 ++++----
 pyiceberg/io/pyarrow.py            | 69 +++++++++++-------------
 tests/expressions/test_visitors.py | 35 +++++++++++++--
 tests/io/test_pyarrow.py           |  6 +--
 4 files changed, 70 insertions(+), 59 deletions(-)

diff --git a/pyiceberg/expressions/visitors.py b/pyiceberg/expressions/visitors.py
index 99cbc0fb66..779f2b476f 100644
--- a/pyiceberg/expressions/visitors.py
+++ b/pyiceberg/expressions/visitors.py
@@ -861,7 +861,7 @@ class _ColumnNameTranslator(BooleanExpressionVisitor[BooleanExpression]):
     Args:
         file_schema (Schema): The schema of the file.
         case_sensitive (bool): Whether to consider case when binding a reference to a field in a schema, defaults to True.
-        projected_field_values (Dict[str, Any]): Values for projected fields not present in the data file.
+        projected_field_values (Dict[int, Any]): Values for projected fields not present in the data file.
 
     Raises:
         TypeError: In the case of an UnboundPredicate.
@@ -870,12 +870,12 @@ class _ColumnNameTranslator(BooleanExpressionVisitor[BooleanExpression]):
 
     file_schema: Schema
     case_sensitive: bool
-    projected_field_values: Dict[str, Any]
+    projected_field_values: Dict[int, Any]
 
-    def __init__(self, file_schema: Schema, case_sensitive: bool, projected_field_values: Dict[str, Any] = EMPTY_DICT) -> None:
+    def __init__(self, file_schema: Schema, case_sensitive: bool, projected_field_values: Dict[int, Any] = EMPTY_DICT) -> None:
         self.file_schema = file_schema
         self.case_sensitive = case_sensitive
-        self.projected_field_values = projected_field_values or {}
+        self.projected_field_values = projected_field_values
 
     def visit_true(self) -> BooleanExpression:
         return AlwaysTrue()
@@ -897,7 +897,8 @@ def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpr
 
     def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpression:
         field = predicate.term.ref().field
-        file_column_name = self.file_schema.find_column_name(field.field_id)
+        field_id = field.field_id
+        file_column_name = self.file_schema.find_column_name(field_id)
 
         if file_column_name is None:
             # In the case of schema evolution or column projection, the field might not be present in the file schema.
@@ -915,8 +916,10 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpressi
             # In the order described by the "Column Projection" section of the Iceberg spec:
             # https://iceberg.apache.org/spec/#column-projection
             # Evaluate column projection first if it exists
-            if projected_field_value := self.projected_field_values.get(field.name):
-                if expression_evaluator(Schema(field), pred, case_sensitive=self.case_sensitive)(Record(projected_field_value)):
+            if field_id in self.projected_field_values:
+                if expression_evaluator(Schema(field), pred, case_sensitive=self.case_sensitive)(
+                    Record(self.projected_field_values[field_id])
+                ):
                     return AlwaysTrue()
 
             # Evaluate initial_default value
@@ -937,7 +940,7 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpressi
 
 
 def translate_column_names(
-    expr: BooleanExpression, file_schema: Schema, case_sensitive: bool, projected_field_values: Dict[str, Any] = EMPTY_DICT
+    expr: BooleanExpression, file_schema: Schema, case_sensitive: bool, projected_field_values: Dict[int, Any] = EMPTY_DICT
 ) -> BooleanExpression:
     return visit(expr, _ColumnNameTranslator(file_schema, case_sensitive, projected_field_values))
 
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index e6992843ca..0d8e998cc6 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -131,7 +131,6 @@
 )
 from pyiceberg.partitioning import PartitionField, PartitionFieldValue, PartitionKey, PartitionSpec, partition_record_value
 from pyiceberg.schema import (
-    Accessor,
     PartnerAccessor,
     PreOrderSchemaVisitor,
     Schema,
@@ -1401,42 +1400,24 @@ def _field_id(self, field: pa.Field) -> int:
 
 
 def _get_column_projection_values(
-    file: DataFile, projected_schema: Schema, partition_spec: Optional[PartitionSpec], file_project_field_ids: Set[int]
-) -> Tuple[bool, Dict[str, Any]]:
+    file: DataFile, projected_schema: Schema, partition_spec: PartitionSpec, file_project_field_ids: Set[int]
+) -> Dict[int, Any]:
     """Apply Column Projection rules to File Schema."""
     project_schema_diff = projected_schema.field_ids.difference(file_project_field_ids)
-    should_project_columns = len(project_schema_diff) > 0
-    projected_missing_fields: Dict[str, Any] = {}
+    if len(project_schema_diff) == 0:
+        return EMPTY_DICT
 
-    if not should_project_columns:
-        return False, {}
-
-    partition_schema: StructType
-    accessors: Dict[int, Accessor]
-
-    if partition_spec is not None:
-        partition_schema = partition_spec.partition_type(projected_schema)
-        accessors = build_position_accessors(partition_schema)
-    else:
-        return False, {}
+    partition_schema = partition_spec.partition_type(projected_schema)
+    accessors = build_position_accessors(partition_schema)
 
+    projected_missing_fields = {}
     for field_id in project_schema_diff:
        for partition_field in partition_spec.fields_by_source_id(field_id):
            if isinstance(partition_field.transform, IdentityTransform):
-                accessor = accessors.get(partition_field.field_id)
-
-                if accessor is None:
-                    continue
-
-                # The partition field may not exist in the partition record of the data file.
-                # This can happen when new partition fields are introduced after the file was written.
-                try:
-                    if partition_value := accessor.get(file.partition):
-                        projected_missing_fields[partition_field.name] = partition_value
-                except IndexError:
-                    continue
+                if partition_value := accessors[partition_field.field_id].get(file.partition):
+                    projected_missing_fields[field_id] = partition_value
 
-    return True, projected_missing_fields
+    return projected_missing_fields
 
 
 def _task_to_record_batches(
@@ -1447,8 +1428,8 @@ def _task_to_record_batches(
     projected_field_ids: Set[int],
     positional_deletes: Optional[List[ChunkedArray]],
     case_sensitive: bool,
+    partition_spec: PartitionSpec,
     name_mapping: Optional[NameMapping] = None,
-    partition_spec: Optional[PartitionSpec] = None,
 ) -> Iterator[pa.RecordBatch]:
     arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
     with io.new_input(task.file.file_path).open() as fin:
@@ -1460,9 +1441,8 @@ def _task_to_record_batches(
         # the table format version.
         file_schema = pyarrow_to_schema(physical_schema, name_mapping, downcast_ns_timestamp_to_us=True)
 
-        # Apply column projection rules
-        # https://iceberg.apache.org/spec/#column-projection
-        should_project_columns, projected_missing_fields = _get_column_projection_values(
+        # Apply column projection rules: https://iceberg.apache.org/spec/#column-projection
+        projected_missing_fields = _get_column_projection_values(
             task.file, projected_schema, partition_spec, file_schema.field_ids
         )
 
@@ -1517,16 +1497,9 @@ def _task_to_record_batches(
                 file_project_schema,
                 current_batch,
                 downcast_ns_timestamp_to_us=True,
+                projected_missing_fields=projected_missing_fields,
             )
 
-            # Inject projected column values if available
-            if should_project_columns:
-                for name, value in projected_missing_fields.items():
-                    index = result_batch.schema.get_field_index(name)
-                    if index != -1:
-                        arr = pa.repeat(value, result_batch.num_rows)
-                        result_batch = result_batch.set_column(index, name, arr)
-
             yield result_batch
 
 
@@ -1695,8 +1668,8 @@ def _record_batches_from_scan_tasks_and_deletes(
                 self._projected_field_ids,
                 deletes_per_file.get(task.file.file_path),
                 self._case_sensitive,
-                self._table_metadata.name_mapping(),
                 self._table_metadata.spec(),
+                self._table_metadata.name_mapping(),
             )
             for batch in batches:
                 if self._limit is not None:
@@ -1714,12 +1687,15 @@ def _to_requested_schema(
     batch: pa.RecordBatch,
     downcast_ns_timestamp_to_us: bool = False,
     include_field_ids: bool = False,
+    projected_missing_fields: Dict[int, Any] = EMPTY_DICT,
 ) -> pa.RecordBatch:
     # We could reuse some of these visitors
     struct_array = visit_with_partner(
         requested_schema,
         batch,
-        ArrowProjectionVisitor(file_schema, downcast_ns_timestamp_to_us, include_field_ids),
+        ArrowProjectionVisitor(
+            file_schema, downcast_ns_timestamp_to_us, include_field_ids, projected_missing_fields=projected_missing_fields
+        ),
         ArrowAccessor(file_schema),
     )
     return pa.RecordBatch.from_struct_array(struct_array)
@@ -1730,6 +1706,7 @@ class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Arra
     _include_field_ids: bool
     _downcast_ns_timestamp_to_us: bool
     _use_large_types: Optional[bool]
+    _projected_missing_fields: Dict[int, Any]
 
     def __init__(
         self,
@@ -1737,11 +1714,13 @@ def __init__(
         downcast_ns_timestamp_to_us: bool = False,
         include_field_ids: bool = False,
         use_large_types: Optional[bool] = None,
+        projected_missing_fields: Dict[int, Any] = EMPTY_DICT,
     ) -> None:
         self._file_schema = file_schema
         self._include_field_ids = include_field_ids
         self._downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us
         self._use_large_types = use_large_types
+        self._projected_missing_fields = projected_missing_fields
 
         if use_large_types is not None:
             deprecation_message(
@@ -1821,7 +1800,9 @@ def struct(
             elif field.optional or field.initial_default is not None:
                 # When an optional field is added, or when a required field with a non-null initial default is added
                 arrow_type = schema_to_pyarrow(field.field_type, include_field_ids=self._include_field_ids)
-                if field.initial_default is None:
+                if projected_value := self._projected_missing_fields.get(field.field_id):
+                    field_arrays.append(pa.repeat(projected_value, len(struct_array)).cast(arrow_type))
+                elif field.initial_default is None:
                     field_arrays.append(pa.nulls(len(struct_array), type=arrow_type))
                 else:
                     field_arrays.append(pa.repeat(field.initial_default, len(struct_array)))
diff --git a/tests/expressions/test_visitors.py b/tests/expressions/test_visitors.py
index f02aadfe44..997cc7f7d7 100644
--- a/tests/expressions/test_visitors.py
+++ b/tests/expressions/test_visitors.py
@@ -1730,6 +1730,33 @@ def test_translate_column_names_missing_column_match_null() -> None:
     assert translated_expr == AlwaysTrue()
 
 
+def test_translate_column_names_missing_column_match_explicit_null() -> None:
+    """Test translate_column_names when missing column matches null."""
+    # Original schema
+    original_schema = Schema(
+        NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
+        NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False),
+        schema_id=1,
+    )
+
+    # Create bound expression for the missing column
+    unbound_expr = IsNull("missing_col")
+    bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True))
+
+    # File schema only has the existing column (field_id=1), missing field_id=2
+    file_schema = Schema(
+        NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
+        schema_id=1,
+    )
+
+    # Translate column names
+    translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True, projected_field_values={2: None})
+
+    # Should evaluate to AlwaysTrue because the missing column is treated as null
+    # missing_col's default initial_default (None) satisfies the IsNull predicate
+    assert translated_expr == AlwaysTrue()
+
+
 def test_translate_column_names_missing_column_with_initial_default() -> None:
     """Test translate_column_names when missing column's initial_default matches expression."""
     # Original schema
@@ -1801,7 +1828,7 @@ def test_translate_column_names_missing_column_with_projected_field_matches() ->
     )
 
     # Projected column that is missing in the file schema
-    projected_field_values = {"missing_col": 42}
+    projected_field_values = {2: 42}
 
     # Translate column names
     translated_expr = translate_column_names(
@@ -1833,7 +1860,7 @@ def test_translate_column_names_missing_column_with_projected_field_mismatch() -
     )
 
     # Projected column that is missing in the file schema
-    projected_field_values = {"missing_col": 1}
+    projected_field_values = {2: 1}
 
     # Translate column names
     translated_expr = translate_column_names(
@@ -1864,7 +1891,7 @@ def test_translate_column_names_missing_column_projected_field_fallbacks_to_init
     )
 
     # Projected field value that differs from both the expression literal and initial_default
-    projected_field_values = {"missing_col": 10}  # This doesn't match expression literal (42)
+    projected_field_values = {2: 10}  # This doesn't match expression literal (42)
 
     # Translate column names
     translated_expr = translate_column_names(
@@ -1895,7 +1922,7 @@ def test_translate_column_names_missing_column_projected_field_matches_initial_d
     )
 
     # Projected field value that matches the expression literal
-    projected_field_values = {"missing_col": 10}  # This doesn't match expression literal (42)
+    projected_field_values = {2: 10}  # This doesn't match expression literal (42)
 
     # Translate column names
     translated_expr = translate_column_names(
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index ac16ef18f6..3bac2ff691 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -1189,7 +1189,7 @@ def test_identity_transform_column_projection(tmp_path: str, catalog: InMemoryCa
     with transaction.update_snapshot().overwrite() as update:
         update.append_data_file(unpartitioned_file)
 
-    schema = pa.schema([("other_field", pa.string()), ("partition_id", pa.int64())])
+    schema = pa.schema([("other_field", pa.string()), ("partition_id", pa.int32())])
     assert table.scan().to_arrow() == pa.table(
         {
             "other_field": ["foo", "bar", "baz"],
@@ -1264,8 +1264,8 @@ def test_identity_transform_columns_projection(tmp_path: str, catalog: InMemoryC
         str(table.scan().to_arrow())
         == """pyarrow.Table
 field_1: string
-field_2: int64
-field_3: int64
+field_2: int32
+field_3: int32
 ----
 field_1: [["foo"]]
 field_2: [[2]]

From b29d80ee00128c4adb51bcb08d1d461291a64ace Mon Sep 17 00:00:00 2001
From: Fokko Driesprong
Date: Thu, 7 Aug 2025 00:23:24 +0200
Subject: [PATCH 2/3] Fix the CI

---
 pyiceberg/io/pyarrow.py                                 | 10 +++++-----
 .../integration/test_writes/test_partitioned_writes.py |  4 +++-
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 0d8e998cc6..cee2ccac72 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -1400,11 +1400,11 @@ def _field_id(self, field: pa.Field) -> int:
 
 
 def _get_column_projection_values(
-    file: DataFile, projected_schema: Schema, partition_spec: PartitionSpec, file_project_field_ids: Set[int]
+    file: DataFile, projected_schema: Schema, partition_spec: Optional[PartitionSpec], file_project_field_ids: Set[int]
 ) -> Dict[int, Any]:
     """Apply Column Projection rules to File Schema."""
     project_schema_diff = projected_schema.field_ids.difference(file_project_field_ids)
-    if len(project_schema_diff) == 0:
+    if len(project_schema_diff) == 0 or partition_spec is None:
         return EMPTY_DICT
 
     partition_schema = partition_spec.partition_type(projected_schema)
@@ -1428,8 +1428,8 @@ def _task_to_record_batches(
     projected_field_ids: Set[int],
     positional_deletes: Optional[List[ChunkedArray]],
     case_sensitive: bool,
-    partition_spec: PartitionSpec,
     name_mapping: Optional[NameMapping] = None,
+    partition_spec: Optional[PartitionSpec] = None,
 ) -> Iterator[pa.RecordBatch]:
     arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
     with io.new_input(task.file.file_path).open() as fin:
@@ -1668,8 +1668,8 @@ def _record_batches_from_scan_tasks_and_deletes(
                 self._projected_field_ids,
                 deletes_per_file.get(task.file.file_path),
                 self._case_sensitive,
-                self._table_metadata.spec(),
                 self._table_metadata.name_mapping(),
+                self._table_metadata.specs().get(task.file.spec_id),
             )
             for batch in batches:
                 if self._limit is not None:
@@ -1801,7 +1801,7 @@ def struct(
                 # When an optional field is added, or when a required field with a non-null initial default is added
                 arrow_type = schema_to_pyarrow(field.field_type, include_field_ids=self._include_field_ids)
                 if projected_value := self._projected_missing_fields.get(field.field_id):
-                    field_arrays.append(pa.repeat(projected_value, len(struct_array)).cast(arrow_type))
+                    field_arrays.append(pa.repeat(pa.scalar(projected_value, type=arrow_type), len(struct_array)))
                 elif field.initial_default is None:
                     field_arrays.append(pa.nulls(len(struct_array), type=arrow_type))
                 else:
diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py
index b2f6ad410d..e9698067c1 100644
--- a/tests/integration/test_writes/test_partitioned_writes.py
+++ b/tests/integration/test_writes/test_partitioned_writes.py
@@ -711,8 +711,10 @@ def test_dynamic_partition_overwrite_evolve_partition(spark: SparkSession, sessi
     )
     identifier = f"default.partitioned_{format_version}_test_dynamic_partition_overwrite_evolve_partition"
 
-    with pytest.raises(NoSuchTableError):
+    try:
         session_catalog.drop_table(identifier)
+    except NoSuchTableError:
+        pass
 
     tbl = session_catalog.create_table(
         identifier=identifier,

From cbd297f6e09f7ad256b41794509fb0260fcb576a Mon Sep 17 00:00:00 2001
From: Fokko Driesprong
Date: Thu, 7 Aug 2025 00:35:02 +0200
Subject: [PATCH 3/3] Use the spec that it was written with

---
 tests/conftest.py        |  4 +++-
 tests/io/test_pyarrow.py | 20 +++++++++++++-------
 2 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index c01ccc979c..16c9e06dac 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2375,8 +2375,10 @@ def data_file(table_schema_simple: Schema, tmp_path: str) -> str:
 
 
 @pytest.fixture
 def example_task(data_file: str) -> FileScanTask:
+    datafile = DataFile.from_args(file_path=data_file, file_format=FileFormat.PARQUET, file_size_in_bytes=1925)
+    datafile.spec_id = 0
     return FileScanTask(
-        data_file=DataFile.from_args(file_path=data_file, file_format=FileFormat.PARQUET, file_size_in_bytes=1925),
+        data_file=datafile,
     )
 
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index 3bac2ff691..f5c3082edc 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -970,6 +970,10 @@ def file_map(schema_map: Schema, tmpdir: str) -> str:
 def project(
     schema: Schema, files: List[str], expr: Optional[BooleanExpression] = None, table_schema: Optional[Schema] = None
 ) -> pa.Table:
+    def _set_spec_id(datafile: DataFile) -> DataFile:
+        datafile.spec_id = 0
+        return datafile
+
     return ArrowScan(
         table_metadata=TableMetadataV2(
             location="file://a/b/",
@@ -985,13 +989,15 @@ def project(
     ).to_table(
         tasks=[
             FileScanTask(
-                DataFile.from_args(
-                    content=DataFileContent.DATA,
-                    file_path=file,
-                    file_format=FileFormat.PARQUET,
-                    partition={},
-                    record_count=3,
-                    file_size_in_bytes=3,
+                _set_spec_id(
+                    DataFile.from_args(
+                        content=DataFileContent.DATA,
+                        file_path=file,
+                        file_format=FileFormat.PARQUET,
+                        partition={},
+                        record_count=3,
+                        file_size_in_bytes=3,
+                    )
                 )
             )
             for file in files