Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 12 additions & 17 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,9 @@
ONE_MEGABYTE = 1024 * 1024
BUFFER_SIZE = "buffer-size"
ICEBERG_SCHEMA = b"iceberg.schema"
FIELD_ID = "field_id"
DOC = "doc"
PYARROW_FIELD_ID_KEYS = [b"PARQUET:field_id", b"field_id"]
PYARROW_FIELD_DOC_KEYS = [b"PARQUET:field_doc", b"field_doc", b"doc"]
# The "PARQUET:" prefix marks a metadata key as Parquet-specific; this key holds the field_id
PYARROW_PARQUET_FIELD_ID_KEY = b"PARQUET:field_id"
PYARROW_FIELD_DOC_KEY = b"doc"

T = TypeVar("T")

Expand Down Expand Up @@ -461,7 +460,9 @@ def field(self, field: NestedField, field_result: pa.DataType) -> pa.Field:
name=field.name,
type=field_result,
nullable=field.optional,
metadata={DOC: field.doc, FIELD_ID: str(field.field_id)} if field.doc else {FIELD_ID: str(field.field_id)},
metadata={PYARROW_FIELD_DOC_KEY: field.doc, PYARROW_PARQUET_FIELD_ID_KEY: str(field.field_id)}
if field.doc
else {PYARROW_PARQUET_FIELD_ID_KEY: str(field.field_id)},
)

def list(self, list_type: ListType, element_result: pa.DataType) -> pa.DataType:
Expand Down Expand Up @@ -725,25 +726,19 @@ def primitive(self, primitive: pa.DataType) -> Optional[T]:


def _get_field_id(field: pa.Field) -> Optional[int]:
for pyarrow_field_id_key in PYARROW_FIELD_ID_KEYS:
if field_id_str := field.metadata.get(pyarrow_field_id_key):
return int(field_id_str.decode())
return None


def _get_field_doc(field: pa.Field) -> Optional[str]:
for pyarrow_doc_key in PYARROW_FIELD_DOC_KEYS:
if doc_str := field.metadata.get(pyarrow_doc_key):
return doc_str.decode()
return None
return (
int(field_id_str.decode())
if (field.metadata and (field_id_str := field.metadata.get(PYARROW_PARQUET_FIELD_ID_KEY)))
else None
)


class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
def _convert_fields(self, arrow_fields: Iterable[pa.Field], field_results: List[Optional[IcebergType]]) -> List[NestedField]:
fields = []
for i, field in enumerate(arrow_fields):
field_id = _get_field_id(field)
field_doc = _get_field_doc(field)
field_doc = doc_str.decode() if (field.metadata and (doc_str := field.metadata.get(PYARROW_FIELD_DOC_KEY))) else None
field_type = field_results[i]
if field_type is not None and field_id is not None:
fields.append(NestedField(field_id, field.name, field_type, required=not field.nullable, doc=field_doc))
Expand Down
54 changes: 27 additions & 27 deletions tests/io/test_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,57 +324,57 @@ def test_schema_to_pyarrow_schema(table_schema_nested: Schema) -> None:
actual = schema_to_pyarrow(table_schema_nested)
expected = """foo: string
-- field metadata --
field_id: '1'
PARQUET:field_id: '1'
bar: int32 not null
-- field metadata --
field_id: '2'
PARQUET:field_id: '2'
baz: bool
-- field metadata --
field_id: '3'
PARQUET:field_id: '3'
qux: list<element: string not null> not null
child 0, element: string not null
-- field metadata --
field_id: '5'
PARQUET:field_id: '5'
-- field metadata --
field_id: '4'
PARQUET:field_id: '4'
quux: map<string, map<string, int32>> not null
child 0, entries: struct<key: string not null, value: map<string, int32> not null> not null
child 0, key: string not null
-- field metadata --
field_id: '7'
PARQUET:field_id: '7'
child 1, value: map<string, int32> not null
child 0, entries: struct<key: string not null, value: int32 not null> not null
child 0, key: string not null
-- field metadata --
field_id: '9'
PARQUET:field_id: '9'
child 1, value: int32 not null
-- field metadata --
field_id: '10'
PARQUET:field_id: '10'
-- field metadata --
field_id: '8'
PARQUET:field_id: '8'
-- field metadata --
field_id: '6'
PARQUET:field_id: '6'
location: list<element: struct<latitude: float, longitude: float> not null> not null
child 0, element: struct<latitude: float, longitude: float> not null
child 0, latitude: float
-- field metadata --
field_id: '13'
PARQUET:field_id: '13'
child 1, longitude: float
-- field metadata --
field_id: '14'
PARQUET:field_id: '14'
-- field metadata --
field_id: '12'
PARQUET:field_id: '12'
-- field metadata --
field_id: '11'
PARQUET:field_id: '11'
person: struct<name: string, age: int32 not null>
child 0, name: string
-- field metadata --
field_id: '16'
PARQUET:field_id: '16'
child 1, age: int32 not null
-- field metadata --
field_id: '17'
PARQUET:field_id: '17'
-- field metadata --
field_id: '15'"""
PARQUET:field_id: '15'"""
assert repr(actual) == expected


Expand Down Expand Up @@ -888,22 +888,22 @@ def test_projection_add_column(file_int: str) -> None:
list: list<element: int32>
child 0, element: int32
-- field metadata --
field_id: '21'
PARQUET:field_id: '21'
map: map<int32, string>
child 0, entries: struct<key: int32 not null, value: string> not null
child 0, key: int32 not null
-- field metadata --
field_id: '31'
PARQUET:field_id: '31'
child 1, value: string
-- field metadata --
field_id: '32'
PARQUET:field_id: '32'
location: struct<lat: double, lon: double>
child 0, lat: double
-- field metadata --
field_id: '41'
PARQUET:field_id: '41'
child 1, lon: double
-- field metadata --
field_id: '42'"""
PARQUET:field_id: '42'"""
)


Expand Down Expand Up @@ -953,10 +953,10 @@ def test_projection_add_column_struct(schema_int: Schema, file_int: str) -> None
child 0, entries: struct<key: int32 not null, value: string> not null
child 0, key: int32 not null
-- field metadata --
field_id: '3'
PARQUET:field_id: '3'
child 1, value: string
-- field metadata --
field_id: '4'"""
PARQUET:field_id: '4'"""
)


Expand Down Expand Up @@ -1004,7 +1004,7 @@ def test_projection_filter(schema_int: Schema, file_int: str) -> None:
repr(result_table.schema)
== """id: int32
-- field metadata --
field_id: '1'"""
PARQUET:field_id: '1'"""
)


Expand Down Expand Up @@ -1182,10 +1182,10 @@ def test_projection_nested_struct_different_parent_id(file_struct: str) -> None:
== """location: struct<lat: double, long: double>
child 0, lat: double
-- field metadata --
field_id: '41'
PARQUET:field_id: '41'
child 1, long: double
-- field metadata --
field_id: '42'"""
PARQUET:field_id: '42'"""
)


Expand Down
12 changes: 6 additions & 6 deletions tests/io/test_pyarrow_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,9 @@ def test_pyarrow_variable_binary_to_iceberg() -> None:

def test_pyarrow_struct_to_iceberg() -> None:
pyarrow_struct = pa.struct([
pa.field("foo", pa.string(), nullable=True, metadata={"field_id": "1", "doc": "foo doc"}),
pa.field("bar", pa.int32(), nullable=False, metadata={"field_id": "2"}),
pa.field("baz", pa.bool_(), nullable=True, metadata={"field_id": "3"}),
pa.field("foo", pa.string(), nullable=True, metadata={"PARQUET:field_id": "1", "doc": "foo doc"}),
pa.field("bar", pa.int32(), nullable=False, metadata={"PARQUET:field_id": "2"}),
pa.field("baz", pa.bool_(), nullable=True, metadata={"PARQUET:field_id": "3"}),
])
expected = StructType(
NestedField(field_id=1, name="foo", field_type=StringType(), required=False, doc="foo doc"),
Expand All @@ -221,7 +221,7 @@ def test_pyarrow_struct_to_iceberg() -> None:


def test_pyarrow_list_to_iceberg() -> None:
pyarrow_list = pa.list_(pa.field("element", pa.int32(), nullable=False, metadata={"field_id": "1"}))
pyarrow_list = pa.list_(pa.field("element", pa.int32(), nullable=False, metadata={"PARQUET:field_id": "1"}))
expected = ListType(
element_id=1,
element_type=IntegerType(),
Expand All @@ -232,8 +232,8 @@ def test_pyarrow_list_to_iceberg() -> None:

def test_pyarrow_map_to_iceberg() -> None:
pyarrow_map = pa.map_(
pa.field("key", pa.int32(), nullable=False, metadata={"field_id": "1"}),
pa.field("value", pa.string(), nullable=False, metadata={"field_id": "2"}),
pa.field("key", pa.int32(), nullable=False, metadata={"PARQUET:field_id": "1"}),
pa.field("value", pa.string(), nullable=False, metadata={"PARQUET:field_id": "2"}),
)
expected = MapType(
key_id=1,
Expand Down