diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index f2d60e7534..78e28d0aaa 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -25,9 +25,11 @@ from __future__ import annotations import concurrent.futures +import itertools import logging import os import re +import warnings from abc import ABC, abstractmethod from concurrent.futures import Future from dataclasses import dataclass @@ -110,6 +112,7 @@ Schema, SchemaVisitorPerPrimitiveType, SchemaWithPartnerVisitor, + assign_fresh_schema_ids, pre_order_visit, promote, prune_columns, @@ -616,7 +619,12 @@ def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], rows: def pyarrow_to_schema(schema: pa.Schema) -> Schema: visitor = _ConvertToIceberg() - return visit_pyarrow(schema, visitor) + schema = visit_pyarrow(schema, visitor) + + if visitor.missing_id_metadata: + return assign_fresh_schema_ids(schema) + else: + return schema @singledispatch @@ -713,28 +721,51 @@ def primitive(self, primitive: pa.DataType) -> Optional[T]: """Visit a primitive type.""" -def _get_field_id(field: pa.Field) -> Optional[int]: - for pyarrow_field_id_key in PYARROW_FIELD_ID_KEYS: - if field_id_str := field.metadata.get(pyarrow_field_id_key): - return int(field_id_str.decode()) - return None +class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]): + counter: itertools.count[int] + missing_id_metadata: Optional[bool] + def __init__(self) -> None: + self.counter = itertools.count(1) + self.missing_id_metadata = None + + def _get_field_id(self, field: pa.Field) -> int: + field_id: Optional[int] = None + + for pyarrow_field_id_key in PYARROW_FIELD_ID_KEYS: + if field.metadata and (field_id_str := field.metadata.get(pyarrow_field_id_key)): + field_id = int(field_id_str.decode()) + + if field_id is None: + if self.missing_id_metadata is None: + warnings.warn( + "Missing field-IDs will be auto-assigned, possibly leading to inconsistencies between the file schema and the schema stored in table metadata." + ) + field_id = next(self.counter) + missing_is_metadata = True + else: + missing_is_metadata = False -def _get_field_doc(field: pa.Field) -> Optional[str]: - for pyarrow_doc_key in PYARROW_FIELD_DOC_KEYS: - if doc_str := field.metadata.get(pyarrow_doc_key): - return doc_str.decode() - return None + if self.missing_id_metadata is not None and self.missing_id_metadata != missing_is_metadata: + raise ValueError("Parquet file contains partial field-ids") + else: + self.missing_id_metadata = missing_is_metadata + return field_id + + def _get_field_doc(self, field: pa.Field) -> Optional[str]: + for pyarrow_doc_key in PYARROW_FIELD_DOC_KEYS: + if field.metadata and (doc_str := field.metadata.get(pyarrow_doc_key)): + return doc_str.decode() + return None -class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]): def _convert_fields(self, arrow_fields: Iterable[pa.Field], field_results: List[Optional[IcebergType]]) -> List[NestedField]: fields = [] for i, field in enumerate(arrow_fields): - field_id = _get_field_id(field) - field_doc = _get_field_doc(field) + field_id = self._get_field_id(field) + field_doc = self._get_field_doc(field) field_type = field_results[i] - if field_type is not None and field_id is not None: + if field_type is not None: fields.append(NestedField(field_id, field.name, field_type, required=not field.nullable, doc=field_doc)) return fields @@ -746,7 +777,7 @@ def struct(self, struct: pa.StructType, field_results: List[Optional[IcebergType def list(self, list_type: pa.ListType, element_result: Optional[IcebergType]) -> Optional[IcebergType]: element_field = list_type.value_field - element_id = _get_field_id(element_field) + element_id = self._get_field_id(element_field) if element_result is not None and element_id is not None: return ListType(element_id, element_result, element_required=not element_field.nullable) return None @@ -755,9 +786,9 @@ def map( self, map_type: pa.MapType, key_result: Optional[IcebergType], value_result: Optional[IcebergType] ) -> Optional[IcebergType]: key_field = map_type.key_field - key_id = _get_field_id(key_field) + key_id = self._get_field_id(key_field) value_field = map_type.item_field - value_id = _get_field_id(value_field) + value_id = self._get_field_id(value_field) if key_result is not None and value_result is not None and key_id is not None and value_id is not None: return MapType(key_id, key_result, value_id, value_result, value_required=not value_field.nullable) return None diff --git a/tests/io/test_pyarrow_visitor.py b/tests/io/test_pyarrow_visitor.py index 5194d8660e..f393aae652 100644 --- a/tests/io/test_pyarrow_visitor.py +++ b/tests/io/test_pyarrow_visitor.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=protected-access,unused-argument,redefined-outer-name import re +from unittest.mock import Mock, patch import pyarrow as pa import pytest @@ -269,3 +270,32 @@ def test_round_schema_conversion_nested(table_schema_nested: Schema) -> None: 15: person: optional struct<16: name: optional string, 17: age: required int> }""" assert actual == expected + + +@patch("warnings.warn") +def test_schema_to_pyarrow_schema_missing_ids(warn: Mock) -> None: + schema = pa.schema([pa.field('some_int', pa.int32(), nullable=True), pa.field('some_string', pa.string(), nullable=False)]) + actual = pyarrow_to_schema(schema) + + expected = Schema( + NestedField(field_id=1, name="some_int", field_type=IntegerType(), required=False), + NestedField(field_id=2, name="some_string", field_type=StringType(), required=True), + ) + + assert actual == expected + assert warn.called + + +@patch("warnings.warn") +def test_schema_to_pyarrow_schema_missing_id(warn: Mock) -> None: + schema = pa.schema( + [ + pa.field('some_int', pa.int32(), nullable=True), + pa.field('some_string', pa.string(), nullable=False, metadata={b"field_id": "22"}), + ] + ) + + with pytest.raises(ValueError) as exc_info: + _ = pyarrow_to_schema(schema) + assert "Parquet file contains partial field-ids" in str(exc_info.value) + assert warn.called