Skip to content

Commit cd7fb50

Browse files
FokkoHonahX
andauthored
Add UnionByName functionality (#296)
* Add UnionByName functionality * Thanks Honah! * Add `_id` Co-authored-by: Honah J. <[email protected]> * Fix --------- Co-authored-by: Honah J. <[email protected]>
1 parent 2a27f2b commit cd7fb50

File tree

4 files changed

+905
-7
lines changed

4 files changed

+905
-7
lines changed

mkdocs/docs/api.md

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,47 @@ with table.transaction() as transaction:
293293
# ... Update properties etc
294294
```
295295

296+
### Union by Name
297+
298+
Using `.union_by_name()` you can merge another schema into an existing schema without having to worry about field-IDs:
299+
300+
```python
301+
from pyiceberg.catalog import load_catalog
302+
from pyiceberg.schema import Schema
303+
from pyiceberg.types import NestedField, StringType, DoubleType, LongType
304+
305+
catalog = load_catalog()
306+
307+
schema = Schema(
308+
NestedField(1, "city", StringType(), required=False),
309+
NestedField(2, "lat", DoubleType(), required=False),
310+
NestedField(3, "long", DoubleType(), required=False),
311+
)
312+
313+
table = catalog.create_table("default.locations", schema)
314+
315+
new_schema = Schema(
316+
NestedField(1, "city", StringType(), required=False),
317+
NestedField(2, "lat", DoubleType(), required=False),
318+
NestedField(3, "long", DoubleType(), required=False),
319+
NestedField(10, "population", LongType(), required=False),
320+
)
321+
322+
with table.update_schema() as update:
323+
update.union_by_name(new_schema)
324+
```
325+
326+
Now the table has the union of the two schemas, as shown by `print(table.schema())`:
327+
328+
```
329+
table {
330+
1: city: optional string
331+
2: lat: optional double
332+
3: long: optional double
333+
4: population: optional long
334+
}
335+
```
336+
296337
### Add column
297338

298339
Using `add_column` you can add a column, without having to worry about the field-id:

pyiceberg/table/__init__.py

Lines changed: 183 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,14 @@
6969
)
7070
from pyiceberg.partitioning import PartitionSpec
7171
from pyiceberg.schema import (
72+
PartnerAccessor,
7273
Schema,
7374
SchemaVisitor,
75+
SchemaWithPartnerVisitor,
7476
assign_fresh_schema_ids,
7577
promote,
7678
visit,
79+
visit_with_partner,
7780
)
7881
from pyiceberg.table.metadata import (
7982
INITIAL_SEQUENCE_NUMBER,
@@ -1379,7 +1382,7 @@ class Move:
13791382

13801383

13811384
class UpdateSchema:
1382-
_table: Table
1385+
_table: Optional[Table]
13831386
_schema: Schema
13841387
_last_column_id: itertools.count[int]
13851388
_identifier_field_names: Set[str]
@@ -1398,14 +1401,23 @@ class UpdateSchema:
13981401

13991402
def __init__(
14001403
self,
1401-
table: Table,
1404+
table: Optional[Table],
14021405
transaction: Optional[Transaction] = None,
14031406
allow_incompatible_changes: bool = False,
14041407
case_sensitive: bool = True,
1408+
schema: Optional[Schema] = None,
14051409
) -> None:
14061410
self._table = table
1407-
self._schema = table.schema()
1408-
self._last_column_id = itertools.count(table.metadata.last_column_id + 1)
1411+
1412+
if isinstance(schema, Schema):
1413+
self._schema = schema
1414+
self._last_column_id = itertools.count(1 + schema.highest_field_id)
1415+
elif table is not None:
1416+
self._schema = table.schema()
1417+
self._last_column_id = itertools.count(1 + table.metadata.last_column_id)
1418+
else:
1419+
raise ValueError("Either provide a table or a schema")
1420+
14091421
self._identifier_field_names = self._schema.identifier_field_names()
14101422

14111423
self._adds = {}
@@ -1449,6 +1461,15 @@ def case_sensitive(self, case_sensitive: bool) -> UpdateSchema:
14491461
self._case_sensitive = case_sensitive
14501462
return self
14511463

1464+
def union_by_name(self, new_schema: Schema) -> UpdateSchema:
    """Stage the changes needed to merge ``new_schema`` into the current schema, matching fields by name.

    Args:
        new_schema: The schema to merge into the schema held by this update.

    Returns:
        This UpdateSchema, so further calls can be chained.
    """
    current_schema = self._schema
    match_case = self._case_sensitive

    # The visitor records adds/updates on this UpdateSchema; the accessor resolves
    # the matching field-ID in the current schema for every field that is visited.
    visitor = UnionByNameVisitor(update_schema=self, existing_schema=current_schema, case_sensitive=match_case)
    accessor = PartnerIdByNameAccessor(partner_schema=current_schema, case_sensitive=match_case)

    # -1 is the partner id that both helper classes treat as the schema root.
    visit_with_partner(new_schema, -1, visitor, accessor)  # type: ignore
    return self
1472+
14521473
def add_column(
14531474
self, path: Union[str, Tuple[str, ...]], field_type: IcebergType, doc: Optional[str] = None, required: bool = False
14541475
) -> UpdateSchema:
@@ -1816,6 +1837,9 @@ def move_after(self, path: Union[str, Tuple[str, ...]], after_name: Union[str, T
18161837

18171838
def commit(self) -> None:
18181839
"""Apply the pending changes and commit."""
1840+
if self._table is None:
1841+
raise ValueError("Requires a table to commit to")
1842+
18191843
new_schema = self._apply()
18201844

18211845
existing_schema_id = next((schema.schema_id for schema in self._table.metadata.schemas if schema == new_schema), None)
@@ -1862,7 +1886,8 @@ def _apply(self) -> Schema:
18621886

18631887
field_ids.add(field.field_id)
18641888

1865-
return Schema(*struct.fields, schema_id=1 + max(self._table.schemas().keys()), identifier_field_ids=field_ids)
1889+
next_schema_id = 1 + (max(self._table.schemas().keys()) if self._table is not None else self._schema.schema_id)
1890+
return Schema(*struct.fields, schema_id=next_schema_id, identifier_field_ids=field_ids)
18661891

18671892
def assign_new_column_id(self) -> int:
18681893
return next(self._last_column_id)
@@ -1995,6 +2020,159 @@ def primitive(self, primitive: PrimitiveType) -> Optional[IcebergType]:
19952020
return primitive
19962021

19972022

2023+
class UnionByNameVisitor(SchemaWithPartnerVisitor[int, bool]):
    """Stage the changes required to merge a new schema into an existing one, matching by name.

    The visited schema is the *new* schema. The partner of every visited node is
    the field-ID of the same-named node in ``existing_schema``; ``-1`` stands for
    the root struct and ``None`` means the node has no counterpart. Every visit
    method returns ``True`` when the visited node is missing from the existing
    schema and ``False`` otherwise. Changes are not applied directly, they are
    recorded on ``update_schema``.
    """

    # The pending update that receives the add/update operations.
    update_schema: UpdateSchema
    # The schema that the new schema is merged into.
    existing_schema: Schema
    # Whether field names are matched case-sensitively.
    case_sensitive: bool

    def __init__(self, update_schema: UpdateSchema, existing_schema: Schema, case_sensitive: bool) -> None:
        self.update_schema = update_schema
        self.existing_schema = existing_schema
        self.case_sensitive = case_sensitive

    def schema(self, schema: Schema, partner_id: Optional[int], struct_result: bool) -> bool:
        """Pass through the result of the root struct."""
        return struct_result

    def struct(self, struct: StructType, partner_id: Optional[int], missing_positions: List[bool]) -> bool:
        """Add the fields that are missing from the partner struct and update the ones that exist.

        Args:
            struct: The struct from the new schema.
            partner_id: Field-ID of the matching struct in the existing schema, -1 for the root.
            missing_positions: One flag per position in ``struct.fields``; a truthy flag marks
                the field at that position as missing from the existing schema.

        Returns:
            True when the struct itself has no partner, False otherwise.

        Raises:
            ValueError: When the partner is not a struct.
        """
        if partner_id is None:
            return True

        fields = struct.fields
        partner_struct = self._find_field_type(partner_id)

        if not partner_struct.is_struct:
            raise ValueError(f"Expected a struct, got: {partner_struct}")

        for pos, missing in enumerate(missing_positions):
            if missing:
                self._add_column(partner_id, fields[pos])
            else:
                field = fields[pos]
                if nested_field := partner_struct.field_by_name(field.name, case_sensitive=self.case_sensitive):
                    self._update_column(field, nested_field)

        return False

    def _add_column(self, parent_id: int, field: NestedField) -> None:
        """Stage ``field`` as a new column below the existing field identified by ``parent_id``."""
        # No column name is found for the root (-1), in which case the field becomes a top-level column.
        if parent_name := self.existing_schema.find_column_name(parent_id):
            path: Tuple[str, ...] = (parent_name, field.name)
        else:
            path = (field.name,)

        self.update_schema.add_column(path=path, field_type=field.field_type, required=field.required, doc=field.doc)

    def _update_column(self, field: NestedField, existing_field: NestedField) -> None:
        """Stage the nullability, type and doc changes that turn ``existing_field`` into ``field``.

        Raises:
            ValueError: When the existing field cannot be resolved to a column name.
        """
        full_name = self.existing_schema.find_column_name(existing_field.field_id)

        if full_name is None:
            raise ValueError(f"Could not find field: {existing_field}")

        # Only relax nullability; an optional column is never made required here.
        if field.optional and existing_field.required:
            self.update_schema.make_column_optional(full_name)

        # Nested types are handled by visiting their children, so only primitives are updated here.
        if field.field_type.is_primitive and field.field_type != existing_field.field_type:
            self.update_schema.update_column(full_name, field_type=field.field_type)

        # Only rewrite the doc when the new schema supplies one that actually differs.
        if field.doc is not None and field.doc != existing_field.doc:
            self.update_schema.update_column(full_name, doc=field.doc)

    def _find_field_type(self, field_id: int) -> IcebergType:
        """Return the existing schema's type for ``field_id``; -1 yields the root struct."""
        if field_id == -1:
            return self.existing_schema.as_struct()
        else:
            return self.existing_schema.find_field(field_id).field_type

    def field(self, field: NestedField, partner_id: Optional[int], field_result: bool) -> bool:
        """Report whether the field is missing from the existing schema."""
        return partner_id is None

    def list(self, list_type: ListType, list_partner_id: Optional[int], element_missing: bool) -> bool:
        """Update the element of an existing list.

        Returns:
            True when the list has no partner, False otherwise.

        Raises:
            ValueError: When the element is reported missing although the list exists,
                or when the partner is not a list.
        """
        if list_partner_id is None:
            return True

        if element_missing:
            raise ValueError("Error traversing schemas: element is missing, but list is present")

        partner_list_type = self._find_field_type(list_partner_id)
        if not isinstance(partner_list_type, ListType):
            raise ValueError(f"Expected list-type, got: {partner_list_type}")

        self._update_column(list_type.element_field, partner_list_type.element_field)

        return False

    def map(self, map_type: MapType, map_partner_id: Optional[int], key_missing: bool, value_missing: bool) -> bool:
        """Update the key and value of an existing map.

        Returns:
            True when the map has no partner, False otherwise.

        Raises:
            ValueError: When the key or value is reported missing although the map exists,
                or when the partner is not a map.
        """
        if map_partner_id is None:
            return True

        if key_missing:
            raise ValueError("Error traversing schemas: key is missing, but map is present")

        if value_missing:
            raise ValueError("Error traversing schemas: value is missing, but map is present")

        partner_map_type = self._find_field_type(map_partner_id)
        if not isinstance(partner_map_type, MapType):
            raise ValueError(f"Expected map-type, got: {partner_map_type}")

        self._update_column(map_type.key_field, partner_map_type.key_field)
        self._update_column(map_type.value_field, partner_map_type.value_field)

        return False

    def primitive(self, primitive: PrimitiveType, primitive_partner_id: Optional[int]) -> bool:
        """Report whether the primitive is missing from the existing schema."""
        return primitive_partner_id is None
2124+
2125+
2126+
class PartnerIdByNameAccessor(PartnerAccessor[int]):
    """Resolve partner field-IDs inside ``partner_schema`` by field name.

    A partner id of ``-1`` denotes the root struct of the partner schema and
    ``None`` denotes "no partner".
    """

    # The schema in which partners are looked up.
    partner_schema: Schema
    # Whether struct fields are matched case-sensitively.
    case_sensitive: bool

    def __init__(self, partner_schema: Schema, case_sensitive: bool) -> None:
        self.partner_schema = partner_schema
        self.case_sensitive = case_sensitive

    def schema_partner(self, partner: Optional[int]) -> Optional[int]:
        """Return the root marker (-1), regardless of the incoming partner."""
        return -1

    def field_partner(self, partner_field_id: Optional[int], field_id: int, field_name: str) -> Optional[int]:
        """Return the id of the field called ``field_name`` inside the partner struct, if any.

        Raises:
            ValueError: When the partner is not a struct.
        """
        if partner_field_id is None:
            return None

        if partner_field_id == -1:
            parent_type = self.partner_schema.as_struct()
        else:
            parent_type = self.partner_schema.find_field(partner_field_id).field_type
            if not parent_type.is_struct:
                raise ValueError(f"Expected StructType: {parent_type}")

        found = parent_type.field_by_name(name=field_name, case_sensitive=self.case_sensitive)
        return found.field_id if found else None

    def list_element_partner(self, partner_list_id: Optional[int]) -> Optional[int]:
        """Return the element field-ID of the partner list, or None without a partner.

        Raises:
            ValueError: When the partner is not a list.
        """
        if partner_list_id is None:
            return None

        partner = self.partner_schema.find_field(partner_list_id)
        if not partner:
            return None

        if not isinstance(partner.field_type, ListType):
            raise ValueError(f"Expected ListType: {partner}")
        return partner.field_type.element_field.field_id

    def map_key_partner(self, partner_map_id: Optional[int]) -> Optional[int]:
        """Return the key field-ID of the partner map, or None without a partner.

        Raises:
            ValueError: When the partner is not a map.
        """
        partner_type = self._partner_map_type(partner_map_id)
        return None if partner_type is None else partner_type.key_field.field_id

    def map_value_partner(self, partner_map_id: Optional[int]) -> Optional[int]:
        """Return the value field-ID of the partner map, or None without a partner.

        Raises:
            ValueError: When the partner is not a map.
        """
        partner_type = self._partner_map_type(partner_map_id)
        return None if partner_type is None else partner_type.value_field.field_id

    def _partner_map_type(self, partner_map_id: Optional[int]) -> Optional[MapType]:
        """Look up the partner map type shared by the key and value accessors."""
        if partner_map_id is None:
            return None

        partner = self.partner_schema.find_field(partner_map_id)
        if not partner:
            return None

        if not isinstance(partner.field_type, MapType):
            raise ValueError(f"Expected MapType: {partner}")
        return partner.field_type
2174+
2175+
19982176
def _add_fields(fields: Tuple[NestedField, ...], adds: Optional[List[NestedField]]) -> Tuple[NestedField, ...]:
19992177
adds = adds or []
20002178
return fields + tuple(adds)

pyiceberg/types.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,18 @@ def field(self, field_id: int) -> Optional[NestedField]:
350350
return field
351351
return None
352352

353+
def field_by_name(self, name: str, case_sensitive: bool = True) -> Optional[NestedField]:
    """Return the first direct child field whose name matches ``name``.

    Args:
        name: The field name to look for.
        case_sensitive: When True (the default) the name must match exactly;
            when False the comparison ignores case.

    Returns:
        The matching field, or None when the struct has no such field.
    """
    if case_sensitive:
        for field in self.fields:
            if field.name == name:
                return field
    else:
        # Lower-case the needle once instead of once per field.
        name_lower = name.lower()
        for field in self.fields:
            if field.name.lower() == name_lower:
                return field
    return None
364+
353365
def __str__(self) -> str:
354366
"""Return the string representation of the StructType class."""
355367
return f"struct<{', '.join(map(str, self.fields))}>"

0 commit comments

Comments
 (0)