Skip to content

Commit 2a27f2b

Browse files
authored
Arrow: Don't copy the list/map when not needed (#252)
1 parent 8464d71 commit 2a27f2b

File tree

4 files changed

+183
-32
lines changed

4 files changed

+183
-32
lines changed

dev/provision.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,3 +320,25 @@
320320
spark.sql(f"ALTER TABLE {catalog_name}.default.test_table_add_column ADD COLUMN b string")
321321

322322
spark.sql(f"INSERT INTO {catalog_name}.default.test_table_add_column VALUES ('2', '2')")
323+
324+
spark.sql(
325+
f"""
326+
CREATE TABLE {catalog_name}.default.test_table_empty_list_and_map (
327+
col_list array<int>,
328+
col_map map<int, int>,
329+
col_list_with_struct array<struct<test:int>>
330+
)
331+
USING iceberg
332+
TBLPROPERTIES (
333+
'format-version'='1'
334+
);
335+
"""
336+
)
337+
338+
spark.sql(
339+
f"""
340+
INSERT INTO {catalog_name}.default.test_table_empty_list_and_map
341+
VALUES (null, null, null),
342+
(array(), map(), array(struct(1)))
343+
"""
344+
)

pyiceberg/io/pyarrow.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@
168168
LIST_ELEMENT_NAME = "element"
169169
MAP_KEY_NAME = "key"
170170
MAP_VALUE_NAME = "value"
171+
DOC = "doc"
171172

172173
T = TypeVar("T")
173174

@@ -1118,12 +1119,20 @@ class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Arra
11181119
def __init__(self, file_schema: Schema):
11191120
self.file_schema = file_schema
11201121

1121-
def cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
1122+
def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
11221123
file_field = self.file_schema.find_field(field.field_id)
11231124
if field.field_type.is_primitive and field.field_type != file_field.field_type:
11241125
return values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type)))
11251126
return values
11261127

1128+
def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field:
1129+
return pa.field(
1130+
name=field.name,
1131+
type=arrow_type,
1132+
nullable=field.optional,
1133+
metadata={DOC: field.doc} if field.doc is not None else None,
1134+
)
1135+
11271136
def schema(self, schema: Schema, schema_partner: Optional[pa.Array], struct_result: Optional[pa.Array]) -> Optional[pa.Array]:
11281137
return struct_result
11291138

@@ -1136,13 +1145,13 @@ def struct(
11361145
fields: List[pa.Field] = []
11371146
for field, field_array in zip(struct.fields, field_results):
11381147
if field_array is not None:
1139-
array = self.cast_if_needed(field, field_array)
1148+
array = self._cast_if_needed(field, field_array)
11401149
field_arrays.append(array)
1141-
fields.append(pa.field(field.name, array.type, field.optional))
1150+
fields.append(self._construct_field(field, array.type))
11421151
elif field.optional:
11431152
arrow_type = schema_to_pyarrow(field.field_type)
11441153
field_arrays.append(pa.nulls(len(struct_array), type=arrow_type))
1145-
fields.append(pa.field(field.name, arrow_type, field.optional))
1154+
fields.append(self._construct_field(field, arrow_type))
11461155
else:
11471156
raise ResolveError(f"Field is required, and could not be found in the file: {field}")
11481157

@@ -1152,24 +1161,32 @@ def field(self, field: NestedField, _: Optional[pa.Array], field_array: Optional
11521161
return field_array
11531162

11541163
def list(self, list_type: ListType, list_array: Optional[pa.Array], value_array: Optional[pa.Array]) -> Optional[pa.Array]:
1155-
return (
1156-
pa.ListArray.from_arrays(list_array.offsets, self.cast_if_needed(list_type.element_field, value_array))
1157-
if isinstance(list_array, pa.ListArray)
1158-
else None
1159-
)
1164+
if isinstance(list_array, pa.ListArray) and value_array is not None:
1165+
if isinstance(value_array, pa.StructArray):
1166+
# This can be removed once this has been fixed:
1167+
# https://github.com/apache/arrow/issues/38809
1168+
list_array = pa.ListArray.from_arrays(list_array.offsets, value_array)
1169+
1170+
arrow_field = pa.list_(self._construct_field(list_type.element_field, value_array.type))
1171+
return list_array.cast(arrow_field)
1172+
else:
1173+
return None
11601174

11611175
def map(
11621176
self, map_type: MapType, map_array: Optional[pa.Array], key_result: Optional[pa.Array], value_result: Optional[pa.Array]
11631177
) -> Optional[pa.Array]:
1164-
return (
1165-
pa.MapArray.from_arrays(
1166-
map_array.offsets,
1167-
self.cast_if_needed(map_type.key_field, key_result),
1168-
self.cast_if_needed(map_type.value_field, value_result),
1178+
if isinstance(map_array, pa.MapArray) and key_result is not None and value_result is not None:
1179+
arrow_field = pa.map_(
1180+
self._construct_field(map_type.key_field, key_result.type),
1181+
self._construct_field(map_type.value_field, value_result.type),
11691182
)
1170-
if isinstance(map_array, pa.MapArray)
1171-
else None
1172-
)
1183+
if isinstance(value_result, pa.StructArray):
1184+
# Arrow does not allow reordering of fields, therefore we have to copy the array :(
1185+
return pa.MapArray.from_arrays(map_array.offsets, key_result, value_result, arrow_field)
1186+
else:
1187+
return map_array.cast(arrow_field)
1188+
else:
1189+
return None
11731190

11741191
def primitive(self, _: PrimitiveType, array: Optional[pa.Array]) -> Optional[pa.Array]:
11751192
return array

tests/integration/test_reads.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,3 +428,16 @@ def test_sanitize_character(catalog: Catalog) -> None:
428428
assert len(arrow_table.schema.names), 1
429429
assert len(table_test_table_sanitized_character.schema().fields), 1
430430
assert arrow_table.schema.names[0] == table_test_table_sanitized_character.schema().fields[0].name
431+
432+
433+
@pytest.mark.integration
434+
@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')])
435+
def test_null_list_and_map(catalog: Catalog) -> None:
436+
table_test_empty_list_and_map = catalog.load_table("default.test_table_empty_list_and_map")
437+
arrow_table = table_test_empty_list_and_map.scan().to_arrow()
438+
assert arrow_table["col_list"].to_pylist() == [None, []]
439+
assert arrow_table["col_map"].to_pylist() == [None, []]
440+
# This should be:
441+
# assert arrow_table["col_list_with_struct"].to_pylist() == [None, [{'test': 1}]]
442+
# Once https://github.com/apache/arrow/issues/38809 has been fixed
443+
assert arrow_table["col_list_with_struct"].to_pylist() == [[], [{'test': 1}]]

tests/io/test_pyarrow.py

Lines changed: 114 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,24 @@ def schema_list_of_structs() -> Schema:
682682
)
683683

684684

685+
@pytest.fixture
686+
def schema_map_of_structs() -> Schema:
687+
return Schema(
688+
NestedField(
689+
5,
690+
"locations",
691+
MapType(
692+
key_id=51,
693+
value_id=52,
694+
key_type=StringType(),
695+
value_type=StructType(NestedField(511, "lat", DoubleType()), NestedField(512, "long", DoubleType())),
696+
element_required=False,
697+
),
698+
required=False,
699+
),
700+
)
701+
702+
685703
@pytest.fixture
686704
def schema_map() -> Schema:
687705
return Schema(
@@ -793,6 +811,25 @@ def file_list_of_structs(schema_list_of_structs: Schema, tmpdir: str) -> str:
793811
)
794812

795813

814+
@pytest.fixture
815+
def file_map_of_structs(schema_map_of_structs: Schema, tmpdir: str) -> str:
816+
pyarrow_schema = schema_to_pyarrow(
817+
schema_map_of_structs, metadata={ICEBERG_SCHEMA: bytes(schema_map_of_structs.model_dump_json(), UTF8)}
818+
)
819+
return _write_table_to_file(
820+
f"file:{tmpdir}/e.parquet",
821+
pyarrow_schema,
822+
pa.Table.from_pylist(
823+
[
824+
{"locations": {"1": {"lat": 52.371807, "long": 4.896029}, "2": {"lat": 52.387386, "long": 4.646219}}},
825+
{"locations": {}},
826+
{"locations": {"3": {"lat": 52.078663, "long": 4.288788}, "4": {"lat": 52.387386, "long": 4.646219}}},
827+
],
828+
schema=pyarrow_schema,
829+
),
830+
)
831+
832+
796833
@pytest.fixture
797834
def file_map(schema_map: Schema, tmpdir: str) -> str:
798835
pyarrow_schema = schema_to_pyarrow(schema_map, metadata={ICEBERG_SCHEMA: bytes(schema_map.model_dump_json(), UTF8)})
@@ -914,7 +951,11 @@ def test_read_list(schema_list: Schema, file_list: str) -> None:
914951
for actual, expected in zip(result_table.columns[0], [list(range(1, 10)), list(range(2, 20)), list(range(3, 30))]):
915952
assert actual.as_py() == expected
916953

917-
assert repr(result_table.schema) == "ids: list<item: int32>\n child 0, item: int32"
954+
assert (
955+
repr(result_table.schema)
956+
== """ids: list<element: int32>
957+
child 0, element: int32"""
958+
)
918959

919960

920961
def test_read_map(schema_map: Schema, file_map: str) -> None:
@@ -927,9 +968,9 @@ def test_read_map(schema_map: Schema, file_map: str) -> None:
927968
assert (
928969
repr(result_table.schema)
929970
== """properties: map<string, string>
930-
child 0, entries: struct<key: string not null, value: string> not null
971+
child 0, entries: struct<key: string not null, value: string not null> not null
931972
child 0, key: string not null
932-
child 1, value: string"""
973+
child 1, value: string not null"""
933974
)
934975

935976

@@ -1063,7 +1104,11 @@ def test_projection_nested_struct_subset(file_struct: str) -> None:
10631104
assert actual.as_py() == {"lat": expected}
10641105

10651106
assert len(result_table.columns[0]) == 3
1066-
assert repr(result_table.schema) == "location: struct<lat: double not null> not null\n child 0, lat: double not null"
1107+
assert (
1108+
repr(result_table.schema)
1109+
== """location: struct<lat: double not null> not null
1110+
child 0, lat: double not null"""
1111+
)
10671112

10681113

10691114
def test_projection_nested_new_field(file_struct: str) -> None:
@@ -1082,7 +1127,11 @@ def test_projection_nested_new_field(file_struct: str) -> None:
10821127
for actual, expected in zip(result_table.columns[0], [None, None, None]):
10831128
assert actual.as_py() == {"null": expected}
10841129
assert len(result_table.columns[0]) == 3
1085-
assert repr(result_table.schema) == "location: struct<null: double> not null\n child 0, null: double"
1130+
assert (
1131+
repr(result_table.schema)
1132+
== """location: struct<null: double> not null
1133+
child 0, null: double"""
1134+
)
10861135

10871136

10881137
def test_projection_nested_struct(schema_struct: Schema, file_struct: str) -> None:
@@ -1111,7 +1160,10 @@ def test_projection_nested_struct(schema_struct: Schema, file_struct: str) -> No
11111160
assert len(result_table.columns[0]) == 3
11121161
assert (
11131162
repr(result_table.schema)
1114-
== "location: struct<lat: double, null: double, long: double> not null\n child 0, lat: double\n child 1, null: double\n child 2, long: double"
1163+
== """location: struct<lat: double, null: double, long: double> not null
1164+
child 0, lat: double
1165+
child 1, null: double
1166+
child 2, long: double"""
11151167
)
11161168

11171169

@@ -1136,28 +1188,75 @@ def test_projection_list_of_structs(schema_list_of_structs: Schema, file_list_of
11361188
result_table = project(schema, [file_list_of_structs])
11371189
assert len(result_table.columns) == 1
11381190
assert len(result_table.columns[0]) == 3
1191+
results = [row.as_py() for row in result_table.columns[0]]
1192+
assert results == [
1193+
[
1194+
{'latitude': 52.371807, 'longitude': 4.896029, 'altitude': None},
1195+
{'latitude': 52.387386, 'longitude': 4.646219, 'altitude': None},
1196+
],
1197+
[],
1198+
[
1199+
{'latitude': 52.078663, 'longitude': 4.288788, 'altitude': None},
1200+
{'latitude': 52.387386, 'longitude': 4.646219, 'altitude': None},
1201+
],
1202+
]
1203+
assert (
1204+
repr(result_table.schema)
1205+
== """locations: list<element: struct<latitude: double not null, longitude: double not null, altitude: double>>
1206+
child 0, element: struct<latitude: double not null, longitude: double not null, altitude: double>
1207+
child 0, latitude: double not null
1208+
child 1, longitude: double not null
1209+
child 2, altitude: double"""
1210+
)
1211+
1212+
1213+
def test_projection_maps_of_structs(schema_map_of_structs: Schema, file_map_of_structs: str) -> None:
1214+
schema = Schema(
1215+
NestedField(
1216+
5,
1217+
"locations",
1218+
MapType(
1219+
key_id=51,
1220+
value_id=52,
1221+
key_type=StringType(),
1222+
value_type=StructType(
1223+
NestedField(511, "latitude", DoubleType()),
1224+
NestedField(512, "longitude", DoubleType()),
1225+
NestedField(513, "altitude", DoubleType(), required=False),
1226+
),
1227+
element_required=False,
1228+
),
1229+
required=False,
1230+
),
1231+
)
1232+
1233+
result_table = project(schema, [file_map_of_structs])
1234+
assert len(result_table.columns) == 1
1235+
assert len(result_table.columns[0]) == 3
11391236
for actual, expected in zip(
11401237
result_table.columns[0],
11411238
[
11421239
[
1143-
{"latitude": 52.371807, "longitude": 4.896029, "altitude": None},
1144-
{"latitude": 52.387386, "longitude": 4.646219, "altitude": None},
1240+
("1", {"latitude": 52.371807, "longitude": 4.896029, "altitude": None}),
1241+
("2", {"latitude": 52.387386, "longitude": 4.646219, "altitude": None}),
11451242
],
11461243
[],
11471244
[
1148-
{"latitude": 52.078663, "longitude": 4.288788, "altitude": None},
1149-
{"latitude": 52.387386, "longitude": 4.646219, "altitude": None},
1245+
("3", {"latitude": 52.078663, "longitude": 4.288788, "altitude": None}),
1246+
("4", {"latitude": 52.387386, "longitude": 4.646219, "altitude": None}),
11501247
],
11511248
],
11521249
):
11531250
assert actual.as_py() == expected
11541251
assert (
11551252
repr(result_table.schema)
1156-
== """locations: list<item: struct<latitude: double not null, longitude: double not null, altitude: double>>
1157-
child 0, item: struct<latitude: double not null, longitude: double not null, altitude: double>
1158-
child 0, latitude: double not null
1159-
child 1, longitude: double not null
1160-
child 2, altitude: double"""
1253+
== """locations: map<string, struct<latitude: double not null, longitude: double not null, altitude: double>>
1254+
child 0, entries: struct<key: string not null, value: struct<latitude: double not null, longitude: double not null, altitude: double> not null> not null
1255+
child 0, key: string not null
1256+
child 1, value: struct<latitude: double not null, longitude: double not null, altitude: double> not null
1257+
child 0, latitude: double not null
1258+
child 1, longitude: double not null
1259+
child 2, altitude: double"""
11611260
)
11621261

11631262

0 commit comments

Comments
 (0)